Extra buffer classes for Stable-Baselines3 that reduce memory usage with minimal overhead.

sb3-extra-buffers

Unofficial implementation of extra Stable-Baselines3 buffer classes. Aims to reduce memory usage drastically with minimal overhead. Featured in SB3 docs :-)

Description: Tired of reading a cool RL paper and realizing that the author is storing a MILLION observations in their replay buffers? Yeah, me too. This project implements several compressed buffer classes that replace Stable Baselines3's standard buffers such as ReplayBuffer and RolloutBuffer. With as few as 2-5 extra lines of code and little added latency, memory usage can be reduced by around 95% (see the memory usage test below)!

Main Goal: Reduce the memory consumption of memory buffers in Reinforcement Learning while adding minimal overhead.

Motivation: Reinforcement Learning is quite memory-hungry due to massive buffer sizes, so let's tackle that by not storing raw frames in full np.float32 or np.uint8 and using something smaller instead. For input data that is sparse and contains large contiguous regions of repeated values, lossless compression can reduce the memory footprint.
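
To put numbers on "memory-hungry", here is a rough back-of-the-envelope sketch. It assumes observations of shape (4, 84, 84), i.e. a frame stack of 4 at 84x84 as in the memory usage test further down, and counts only the observation array (no actions, rewards, or separate next-observation storage). The helper name raw_buffer_bytes is made up for this illustration and is not part of the library.

```python
def raw_buffer_bytes(n_obs, obs_shape, bytes_per_elem):
    """Bytes needed to store n_obs uncompressed observations."""
    n_elems = 1
    for dim in obs_shape:
        n_elems *= dim
    return n_obs * n_elems * bytes_per_elem


# Assumed observation shape: 4 stacked 84x84 grayscale frames.
OBS_SHAPE = (4, 84, 84)

uint8_40k = raw_buffer_bytes(40_000, OBS_SHAPE, 1)      # np.uint8 is 1 byte
uint8_1m = raw_buffer_bytes(1_000_000, OBS_SHAPE, 1)
float32_1m = raw_buffer_bytes(1_000_000, OBS_SHAPE, 4)  # np.float32 is 4 bytes

print(f"40k obs, uint8:  {uint8_40k / 2**30:.2f} GiB")   # 1.05 GiB
print(f"1M obs, uint8:   {uint8_1m / 2**30:.1f} GiB")    # 26.3 GiB
print(f"1M obs, float32: {float32_1m / 2**30:.1f} GiB")  # 105.1 GiB
```

The 40,000-observation figure lines up with the 1.05GB baseline reported in the memory usage test, assuming that number is in binary units.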

Applicable Input Types:

  • Semantic Segmentation masks (1 color channel)
  • Color Palette game frames from retro video games
  • Grayscale observations
  • RGB (Color) observations
  • Note: for noisy input with a lot of variation (mostly RGB), zstd is recommended. Run-length encoding does not work as well there and can even increase memory usage. See the benchmark.
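
Why run-length encoding helps on mask-like input but can backfire on noisy input can be seen in a small pure-Python sketch. This is only an illustration of the idea, not the library's vectorized rle or rle-jit implementation; the function names and the toy byte strings are made up for the example.

```python
from itertools import groupby


def rle_encode(data):
    """Encode bytes as a list of (value, run_length) pairs."""
    return [(value, sum(1 for _ in run)) for value, run in groupby(data)]


def rle_decode(pairs):
    """Rebuild the original bytes from (value, run_length) pairs."""
    return b"".join(bytes([value]) * length for value, length in pairs)


# Mask-like row: two long runs of a single class ID each.
mask_row = bytes([0] * 50 + [7] * 50)
# Noisy row: every value differs from its neighbour, so every run has length 1.
noisy_row = bytes(range(100))

mask_pairs = rle_encode(mask_row)
noisy_pairs = rle_encode(noisy_row)

print(len(mask_row), "bytes ->", len(mask_pairs), "pairs")    # 100 bytes -> 2 pairs
print(len(noisy_row), "bytes ->", len(noisy_pairs), "pairs")  # 100 bytes -> 100 pairs
```

In the noisy case each of the 100 input bytes turns into a (value, length) pair, so the encoded form needs more storage than the raw row. That is the situation in which a general-purpose compressor such as zstd is the safer choice.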

Implemented Compression Methods:

  • none: No compression other than casting to elem_type and storing as bytes.
  • rle: Vectorized Run-Length Encoding.
  • rle-jit: JIT-compiled version of rle, uses the numba library.
  • gzip: Built-in gzip compression via the gzip module.
  • igzip: Intel-accelerated variant via isal.igzip, uses the python-isal library.
  • zstd: Zstandard compression via python-zstd. (Recommended)
  • lz4-frame: LZ4 (frame format) compression via python-lz4.
  • lz4-block: LZ4 (block format) compression via python-lz4.
  • gzip supports compression levels 0~9; 0 is no compression and 1 is the least compression.
  • igzip supports compression levels 0~3; 0 is the least compression.
  • zstd supports standard compression levels 1~22 and ultra-fast compression levels -100~-1; -100 is the fastest and 22 is the slowest.
  • lz4-frame supports standard compression levels 0~16; negative levels translate into an acceleration factor.
  • lz4-block supports three modes, selected by the sign of the compression level: levels 1~12 use high_compression mode, negative levels translate into an acceleration factor in fast mode, and 0 enables default mode.
  • Shorthands are supported (for lz4 methods, including the / is required):
    • pattern = ^((?:[A-Za-z]+)|(?:[\w\-]+/))(\-?[0-9]+)$
    • igzip3 = igzip/3 = igzip level 3
    • zstd-5 = zstd/-5 = zstd level -5
    • lz4-frame/5 = lz4-frame level 5
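
The shorthand pattern above can be tried out directly with Python's re module. The sketch below only shows how that regular expression splits a string into a method name and a level; the helper parse_shorthand is hypothetical, and the library's own parsing may handle more cases or behave differently.

```python
import re

# Pattern copied from the shorthand description above.
SHORTHAND = re.compile(r"^((?:[A-Za-z]+)|(?:[\w\-]+/))(\-?[0-9]+)$")


def parse_shorthand(spec):
    """Split e.g. 'zstd-5' into ('zstd', -5); return (spec, None) if no level is given."""
    match = SHORTHAND.match(spec)
    if match is None:
        return spec, None
    method = match.group(1).rstrip("/")  # drop the optional '/' separator
    return method, int(match.group(2))


for spec in ("igzip3", "igzip/3", "zstd-5", "zstd/-5", "lz4-frame/5", "lz4-frame5", "rle-jit"):
    print(f"{spec!r:15} -> {parse_shorthand(spec)}")
```

Note that "lz4-frame5" does not match: the letters-only branch of the pattern cannot consume the digit and hyphen in "lz4-frame", which is why the / is required for the lz4 methods.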

Installation

Install via PyPI:

pip install "sb3-extra-buffers[fast,extra]"

Other install options:

pip install "sb3-extra-buffers"          # only installs minimum requirements
pip install "sb3-extra-buffers[extra]"   # installs extra dependencies for SB3
pip install "sb3-extra-buffers[fast]"    # installs python-isal, numba, zstd, lz4
pip install "sb3-extra-buffers[isal]"    # only installs python-isal
pip install "sb3-extra-buffers[numba]"   # only installs numba
pip install "sb3-extra-buffers[zstd]"    # only installs python-zstd
pip install "sb3-extra-buffers[lz4]"     # only installs python-lz4
pip install "sb3-extra-buffers[vizdoom]" # installs vizdoom

Example Usage

from stable_baselines3 import PPO
from stable_baselines3.common.utils import get_linear_fn
from stable_baselines3.common.callbacks import EvalCallback
from sb3_extra_buffers.compressed import CompressedRolloutBuffer, find_buffer_dtypes
from sb3_extra_buffers.training_utils.atari import make_env

ATARI_GAME = "MsPacmanNoFrameskip-v4"

if __name__ == "__main__":
    # Get the most suitable dtypes for CompressedRolloutBuffer to use
    obs = make_env(env_id=ATARI_GAME, n_envs=1, framestack=4).observation_space
    compression = "rle-jit"  # or use "igzip1", since MsPacman frames are relatively noisy
    buffer_dtypes = find_buffer_dtypes(obs_shape=obs.shape, elem_dtype=obs.dtype, compression_method=compression)

    # Create vectorized environments after the find_buffer_dtypes call, which initializes jit
    env = make_env(env_id=ATARI_GAME, n_envs=8, framestack=4)
    eval_env = make_env(env_id=ATARI_GAME, n_envs=10, framestack=4)

    # Create PPO model with CompressedRolloutBuffer as rollout buffer class
    model = PPO("CnnPolicy", env, verbose=1, learning_rate=get_linear_fn(2.5e-4, 0, 1), n_steps=128,
                batch_size=256, clip_range=get_linear_fn(0.1, 0, 1), n_epochs=4, ent_coef=0.01, vf_coef=0.5,
                seed=1970626835, device="mps",  # "mps" targets Apple Silicon; use "auto" on other hardware
                rollout_buffer_class=CompressedRolloutBuffer,
                rollout_buffer_kwargs=dict(dtypes=buffer_dtypes, compression_method=compression))

    # Evaluation callback (optional)
    eval_callback = EvalCallback(eval_env, n_eval_episodes=20, eval_freq=8192, log_path=f"./logs/{ATARI_GAME}/ppo/eval",
                                 best_model_save_path=f"./logs/{ATARI_GAME}/ppo/best_model")

    # Training
    model.learn(total_timesteps=10_000_000, callback=eval_callback, progress_bar=True)

    # Save the final model
    model.save("ppo_MsPacman_4.zip")

    # Cleanup
    env.close()
    eval_env.close()

Current Project Structure

sb3_extra_buffers
    |- compressed
    |    |- CompressedRolloutBuffer: RolloutBuffer with compression
    |    |- CompressedReplayBuffer: ReplayBuffer with compression
    |    |- CompressedArray: Compressed numpy.ndarray subclass
    |    |- find_buffer_dtypes: Find suitable buffer dtypes and initialize jit
    |
    |- recording
    |    |- RecordBuffer: A buffer for recording game states
    |    |- FramelessRecordBuffer: RecordBuffer but not recording game frames
    |    |- DummyRecordBuffer: Dummy RecordBuffer, records nothing
    |
    |- training_utils
         |- eval_model: Evaluate models in vectorized environment
         |- warmup: Perform buffer warmup for off-policy algorithms

Example Scripts

Example scripts are included and have been tested to confirm that they work properly.

Evaluation results for example training scripts:

PPO on PongNoFrameskip-v4, trained for 10M steps using rle-jit, framestack: None

(Best ) Evaluated 10000 episodes, mean reward: 21.0 +/- 0.00
Q1:   21 | Q2:   21 | Q3:   21 | Relative IQR: 0.00 | Min: 21 | Max: 21
(Final) Evaluated 10000 episodes, mean reward: 21.0 +/- 0.02
Q1:   21 | Q2:   21 | Q3:   21 | Relative IQR: 0.00 | Min: 20 | Max: 21

PPO on MsPacmanNoFrameskip-v4, trained for 10M steps using rle-jit, framestack: 4

(Best ) Evaluated 10000 episodes, mean reward: 2667.0 +/- 290.00
Q1: 2300 | Q2: 2490 | Q3: 3000 | Relative IQR: 0.28 | Min: 2300 | Max: 3000
(Final) Evaluated 10000 episodes, mean reward: 2500.9 +/- 221.03
Q1: 2300 | Q2: 2390 | Q3: 2490 | Relative IQR: 0.08 | Min: 1420 | Max: 3000

DQN on MsPacmanNoFrameskip-v4, trained for 10M steps using rle-jit, framestack: 4

(Best ) Evaluated 10000 episodes, mean reward: 3300.0 +/- 770.79
Q1: 2490 | Q2: 4020 | Q3: 4020 | Relative IQR: 0.38 | Min: 2460 | Max: 4020
(Final) Evaluated 10000 episodes, mean reward: 3379.2 +/- 453.78
Q1: 2690 | Q2: 3400 | Q3: 3880 | Relative IQR: 0.35 | Min: 1230 | Max: 4090

Pytest

Make sure pytest and optionally pytest-xdist are already installed. Tests are compatible with pytest-xdist since DummyVecEnv is used for all tests.

# pytest
pytest tests -v --durations=0 --tb=short
# pytest-xdist
pytest tests -n auto -v --durations=0 --tb=short

Compressed Buffers

Defined in sb3_extra_buffers.compressed

JIT Before Multi-Processing: When using rle-jit, remember to trigger JIT compilation via find_buffer_dtypes or init_jit before any multi-processing code is executed.

# Other code...

# Get observation space from environment
obs = make_env(env_id=ATARI_GAME, n_envs=1, framestack=4).observation_space

# Get the buffer datatype settings via find_buffer_dtypes
compression = "rle-jit"
buffer_dtypes = find_buffer_dtypes(obs_shape=obs.shape, elem_dtype=obs.dtype, compression_method=compression)

# Now, safe to initialize multi-processing environments!
env = SubprocVecEnv(...)

Memory Usage Test for Compressed Buffers (on MsPacmanNoFrameskip-v4)

  • Frame Stack & Vec Envs: both 4
  • Buffer Size: 40,000 (split across 4 vectorized environments)
  • Using optimize_memory_usage: True
  • Steps Per Env: 12,475 steps for each env (49,900 steps in total, but truncated to 40,000)
  • Evaluation Time: 0:16:21 on an M4 MacBook Air, using the mps backend (out of which almost 10 minutes were spent on zstd22 alone?!).
  • Settings: The example DQN model is loaded and evaluated using the code in examples/example_eval_memory_saving.py. The exact same observations are stored into each buffer. Latency refers to the total number of seconds spent adding observations to the specific buffer, and baseline refers to using ReplayBuffer directly.
  • TLDR:
    • zstd in general offers a very decent trade-off between save latency and memory savings
    • zstd-1 ~ zstd-5 seem to be the sweet spot
    • gzip0 should be avoided
    • MsPacman at 84x84 resolution is too visually noisy for rle

Compression     Memory    Memory %   Latency
baseline        1.05GB    100.0%     0.9s
none            1.05GB    100.1%     1.2s
zstd-100        387MB     36.0%      1.8s
zstd-50         306MB     28.4%      1.9s
lz4-frame/1     118MB     10.9%      2.1s
gzip0           1.05GB    100.2%     2.1s
zstd-5          82.9MB    7.7%       2.1s
zstd-20         181MB     16.8%      2.2s
zstd-3          73.9MB    6.9%       2.3s
zstd-1          66.0MB    6.1%       2.3s
zstd1           61.3MB    5.7%       2.7s
zstd3           59.4MB    5.5%       3.0s
igzip0          129MB     12.0%      3.4s
rle             811MB     75.3%      4.0s
lz4-block/1     83.2MB    7.7%       4.6s
igzip1          114MB     10.6%      5.0s
zstd5           55.9MB    5.2%       5.4s
lz4-block/5     75.1MB    7.0%       6.3s
lz4-frame/5     75.9MB    7.0%       6.5s
gzip1           104MB     9.6%       7.6s
gzip3           81.9MB    7.6%       8.3s
igzip3          81.5MB    7.6%       10.5s
zstd10          52.8MB    4.9%       10.8s
lz4-block/9     72.0MB    6.7%       20.0s
lz4-frame/9     72.7MB    6.8%       20.0s
lz4-block/16    71.3MB    6.6%       57.9s
lz4-frame/12    72.0MB    6.7%       58.4s
zstd15          48.5MB    4.5%       99.8s
zstd22          47.6MB    4.4%       590.7s
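
The gzip0 row ends up slightly above 100% because gzip level 0 stores the data uncompressed and still adds its own header and trailer. A small standard-library sketch makes this visible. The synthetic frame below is an assumption for illustration (84 rows of one repeated value each), so the sizes it prints are not comparable to the benchmark numbers above.

```python
import gzip

# Hypothetical stand-in for an 84x84 frame: each row is a run of a single value.
frame = b"".join(bytes([row % 4]) * 84 for row in range(84))

# Compress the same bytes at level 0 (store only), 1 (fastest) and 9 (smallest).
sizes = {level: len(gzip.compress(frame, compresslevel=level)) for level in (0, 1, 9)}

print(f"raw: {len(frame)} bytes")
for level, size in sizes.items():
    print(f"gzip level {level}: {size} bytes ({size / len(frame):.1%} of raw)")
```

Level 0 comes out larger than the raw frame, while levels 1 and 9 shrink this highly repetitive input considerably. The same effect explains why gzip0 pays extra latency in the table without saving any memory.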

Recording Buffers

Defined in sb3_extra_buffers.recording. Mainly used in combination with SegDoom to record stuff.

WIP


Training Utils

Defined in sb3_extra_buffers.training_utils. Provides buffer warm-up and model evaluation.

WIP
