Extra buffer classes for Stable-Baselines3

Reason this release was yanked:

Numba JIT Caching issue

sb3-extra-buffers

Unofficial implementation of extra Stable-Baselines3 buffer classes, currently mostly a proof of concept. Main goal: reduce the memory consumption of buffers in Reinforcement Learning.

Motivation: Reinforcement Learning is quite memory-hungry due to massive buffer sizes, so let's try to tackle that by not storing raw frames directly in full np.float32 and storing something smaller instead. For input data that is sparse and contains large contiguous regions of repeated values, lossless compression techniques can be applied to reduce the memory footprint.

Applicable Input Types:

  • Semantic Segmentation masks (1 color channel)
  • Color Palette game frames from retro video games
  • Grayscale game frames from retro video games
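To get a feel for the savings on this kind of input, here is a minimal sketch using only numpy and the standard-library gzip module. The 84x84 mask and its region layout are made up for illustration, and the code does not use sb3_extra_buffers itself.

```python
import gzip
import numpy as np

# Synthetic 84x84 segmentation mask with large contiguous regions (illustrative only)
mask = np.zeros((84, 84), dtype=np.uint8)
mask[20:60, 10:40] = 1
mask[50:80, 45:80] = 2

# The same frame stored as float32, as uint8, and as gzip-compressed uint8 bytes
as_float32 = mask.astype(np.float32)
compressed = gzip.compress(mask.tobytes())

print(as_float32.nbytes)  # 28224 bytes (84 * 84 * 4)
print(mask.nbytes)        # 7056 bytes (84 * 84 * 1)
print(len(compressed))    # far fewer bytes; exact size depends on the zlib build

# The compression is lossless: the original mask is recovered exactly
restored = np.frombuffer(gzip.decompress(compressed), dtype=np.uint8).reshape(mask.shape)
assert np.array_equal(restored, mask)
```

Casting to a small integer type already cuts memory by 4x compared with np.float32, and compression helps further when frames contain long runs of identical values.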

Installation

To install with isal and numba support:

pip install "sb3_extra_buffers[fast]"

Other install options:

pip install sb3_extra_buffers            # only installs minimum requirements
pip install "sb3_extra_buffers[isal]"    # only installs python-isal
pip install "sb3_extra_buffers[numba]"   # only installs numba
pip install "sb3_extra_buffers[atari]"   # installs gymnasium, ale-py
pip install "sb3_extra_buffers[vizdoom]" # installs gymnasium, vizdoom

Project Structure

sb3_extra_buffers
    |- compressed
    |    |- CompressedRolloutBuffer: RolloutBuffer with compression
    |    |- CompressedReplayBuffer: ReplayBuffer with compression
    |
    |- recording
         |- RecordBuffer: A buffer for recording game states
         |- FramelessRecordBuffer: RecordBuffer but not recording game frames
         |- DummyRecordBuffer: Dummy RecordBuffer, records nothing

Compressed Buffers

Defined in sb3_extra_buffers.compressed

Implemented Compression Methods:

  • rle: Uses Run-Length Encoding for compression.
  • rle-jit: JIT-compiled version of rle, uses the numba library.
  • gzip: Compression via gzip.
  • igzip: Compression via isal.igzip, uses the python-isal library.
  • none: No compression other than casting to elem_type.
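As an illustration of what run-length encoding does with elem_type and runs_type, here is a minimal numpy sketch. The function names rle_encode and rle_decode are hypothetical and this is not the package's actual implementation; it only shows the general technique of storing each value once together with the length of its run.

```python
import numpy as np

def rle_encode(flat, elem_type=np.uint8, runs_type=np.uint16):
    """Encode an array as (values, run_lengths). Hypothetical helper."""
    flat = np.asarray(flat).ravel()
    if flat.size == 0:
        return flat.astype(elem_type), np.empty(0, dtype=runs_type)
    # Indices where a new run starts (position 0 plus every change of value)
    starts = np.concatenate(([0], np.flatnonzero(flat[1:] != flat[:-1]) + 1))
    # Run length = distance from each start to the next start (or to the end)
    lengths = np.diff(np.concatenate((starts, [flat.size])))
    if lengths.max() > np.iinfo(runs_type).max:
        raise ValueError("runs_type is too small for the longest run")
    return flat[starts].astype(elem_type), lengths.astype(runs_type)

def rle_decode(values, lengths):
    """Rebuild the flat array by repeating each value by its run length."""
    return np.repeat(values, lengths)

frame = np.array([0, 0, 0, 0, 7, 7, 3, 3, 3, 0, 0], dtype=np.uint8)
values, lengths = rle_encode(frame)
print(values)   # [0 7 3 0]
print(lengths)  # [4 2 3 2]
assert np.array_equal(rle_decode(values, lengths), frame)
```

The encoding pays off only when runs are long: each run costs one elem_type entry plus one runs_type entry, so noisy frames can end up larger than the raw data.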

JIT Before Multi-Processing: When using rle-jit, remember to trigger JIT compilation before any multi-processing code is executed.

# Other setup code...
import numpy as np
from stable_baselines3.common.vec_env import SubprocVecEnv
from sb3_extra_buffers.compressed.compression_methods import HAS_NUMBA

# Compressed-buffer-related settings
compression_method = "rle-jit"
storage_dtypes = dict(elem_type=np.uint8, runs_type=np.uint16)

# Trigger Numba JIT compilation up front to avoid fork issues
if HAS_NUMBA and "jit" in compression_method:
    from sb3_extra_buffers.compressed.compression_methods.compression_methods_numba import init_jit
    init_jit(**storage_dtypes)

# Now it is safe to initialize multi-processing environments.
# make_env is your own environment factory function (not defined here).
env = SubprocVecEnv([make_env for _ in range(4)])

Example Usage:

import numpy as np
import gymnasium as gym
from stable_baselines3 import PPO
from sb3_extra_buffers.compressed import CompressedRolloutBuffer, find_smallest_dtype

env = gym.make("CartPole-v1", render_mode="human")
flatten_obs_shape = np.prod(env.observation_space.shape)
buffer_dtypes = dict(elem_type=np.uint8, runs_type=find_smallest_dtype(flatten_obs_shape))

model = PPO("MlpPolicy", env, verbose=1, rollout_buffer_class=CompressedRolloutBuffer,
            rollout_buffer_kwargs=dict(dtypes=buffer_dtypes, compression_method="rle"))
model.learn(total_timesteps=10_000)

vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render()

env.close()
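The example above passes the flattened observation size to find_smallest_dtype to pick runs_type. The internals of that helper are not shown here, so the following is only a sketch of the underlying idea under an assumption: a run can never be longer than the flattened observation, so the run-length type only needs to hold that many elements. The name smallest_uint_dtype is hypothetical.

```python
import numpy as np

def smallest_uint_dtype(max_value):
    """Return the smallest unsigned integer dtype that can hold max_value.

    Hypothetical helper for illustration; not the package's find_smallest_dtype.
    """
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if max_value <= np.iinfo(dtype).max:
            return dtype
    raise ValueError(f"{max_value} does not fit in any unsigned integer dtype")

# Flattened observation sizes for a few illustrative shapes
print(smallest_uint_dtype(4).__name__)              # uint8  (4 elements)
print(smallest_uint_dtype(84 * 84).__name__)        # uint16 (7056 elements)
print(smallest_uint_dtype(210 * 160 * 3).__name__)  # uint32 (100800 elements)
```

Choosing the narrowest type that fits keeps the run-length array small without risking overflow on a frame that consists of a single long run.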

Recording Buffers

Defined in sb3_extra_buffers.recording. Mainly used in combination with SegDoom to record game states.

Source Distribution

sb3_extra_buffers-0.1.6.tar.gz (12.7 kB view details)

Uploaded Source

Built Distribution

sb3_extra_buffers-0.1.6-py3-none-any.whl (13.3 kB view details)

Uploaded Python 3

File details

Details for the file sb3_extra_buffers-0.1.6.tar.gz.

File metadata

  • Download URL: sb3_extra_buffers-0.1.6.tar.gz
  • Upload date:
  • Size: 12.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.3

File hashes

Hashes for sb3_extra_buffers-0.1.6.tar.gz
Algorithm Hash digest
SHA256 395bc332014e80707f46b9897985d2b7a841048eafcb2321f5bc8e6fac8e5701
MD5 ee843995cceeecdc1c473b7164e27c22
BLAKE2b-256 888c4b4aa3bab64f1b5e10a5b75132e437204a72797a4cda3f16728df13aa710

File details

Details for the file sb3_extra_buffers-0.1.6-py3-none-any.whl.

File hashes

Hashes for sb3_extra_buffers-0.1.6-py3-none-any.whl
Algorithm Hash digest
SHA256 0b15a34886e0bd17a12aba4c043384adde81d3b76bfd15bdcbc16f0fa31dd39c
MD5 67ad69cfb093f8f38c974efb81b1ea30
BLAKE2b-256 253a1ca1011d6447a68d3d706d9fd387499fe514438922a859c50cace7bb0b08
