Extra buffer classes for Stable-Baselines3 that reduce memory usage with minimal overhead.

sb3-extra-buffers

An unofficial implementation of additional Stable-Baselines3 buffer classes, aiming to drastically reduce memory usage with minimal overhead.

Main Goal: Reduce the memory consumption of replay and rollout buffers in Reinforcement Learning while adding minimal overhead.

Current Progress & Available Features: see Issue https://github.com/Trenza1ore/sb3-extra-buffers/issues/1

Memory Savings: see the memory usage test under Compressed Buffers below.

Motivation: Reinforcement Learning is memory-hungry because of its massive buffer sizes, so this project tackles the problem by not storing raw frame buffers directly as full np.float32 or np.uint8 arrays, and finding something smaller instead. For input data that is sparse and contains large contiguous regions of repeated values, lossless compression can be applied to reduce the memory footprint.
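To illustrate why run-length encoding suits sparse, blocky data, here is a minimal, self-contained sketch of vectorized RLE in NumPy. It is not the library's rle implementation; the function names and the storage layout (uint8 run values plus uint16 run lengths) are assumptions made for the example.

```python
import numpy as np


def rle_encode(arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Run-length encode a flattened array into (run values, run lengths)."""
    flat = arr.ravel()
    if flat.size == 0:
        return flat[:0], np.zeros(0, dtype=np.int64)
    # A run starts at index 0 and wherever a value differs from its predecessor
    is_start = np.concatenate(([True], flat[1:] != flat[:-1]))
    starts = np.flatnonzero(is_start)
    # Each run ends where the next one starts (or at the end of the array)
    lengths = np.diff(np.append(starts, flat.size))
    return flat[starts], lengths


def rle_decode(values: np.ndarray, lengths: np.ndarray, shape, dtype) -> np.ndarray:
    """Invert rle_encode by repeating each value by its run length."""
    return np.repeat(values, lengths).astype(dtype).reshape(shape)


rng = np.random.default_rng(0)
mask = np.zeros((84, 84), dtype=np.uint8)
mask[20:60, 10:70] = 3  # one large uniform region, like a segmentation mask
noise = rng.integers(0, 256, size=(84, 84), dtype=np.uint8)  # noisy frame

for name, frame in [("mask", mask), ("noise", noise)]:
    values, lengths = rle_encode(frame)
    assert np.array_equal(rle_decode(values, lengths, frame.shape, frame.dtype), frame)
    # Assumed storage: uint8 run values plus uint16 run lengths (84 * 84 = 7056 fits in uint16)
    encoded = values.nbytes + lengths.astype(np.uint16).nbytes
    print(f"{name}: {len(values)} runs, raw {frame.nbytes} B -> encoded {encoded} B")
```

The mask collapses to 81 runs, while the random frame produces nearly one run per pixel, so its encoding ends up larger than the raw array. That is the same effect the note on noisy input below warns about.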

Applicable Input Types:

  • Semantic Segmentation masks (1 color channel)
  • Color Palette game frames from retro video games
  • Grayscale observations
  • RGB (Color) observations
  • For noisy input with a lot of variation (mostly RGB), gzip1 or igzip0 is recommended; run-length encoding works poorly on such data and can even increase memory usage.

Implemented Compression Methods:

  • rle: Vectorized run-length encoding.
  • rle-jit: JIT-compiled version of rle, using the numba library.
  • gzip: Gzip compression via gzip.
  • igzip: Intel-accelerated variant via isal.igzip, using the python-isal library.
  • none: No compression other than casting to elem_type and storing as bytes.
  • gzip supports compression levels 0-9: 0 is no compression, 1 is the least compression.
  • igzip supports compression levels 0-3: 0 is the least compression.
  • Shorthands are supported, e.g. igzip3 = igzip at level 3.
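The shorthand convention and the level ranges above can be captured in a small parser. The sketch below is hypothetical: parse_compression_method, KNOWN_METHODS and LEVEL_RANGES are names invented for illustration, and the library's own parsing may behave differently (for example, in which default level it uses when none is given). The second half only uses Python's standard gzip module to show that every level is lossless and that level 0 stores data without shrinking it.

```python
import gzip
import re
from typing import Optional, Tuple

# Hypothetical helper for illustration; sb3-extra-buffers' own parsing may differ.
KNOWN_METHODS = {"rle", "rle-jit", "gzip", "igzip", "none"}
LEVEL_RANGES = {"gzip": range(0, 10), "igzip": range(0, 4)}  # gzip 0-9, igzip 0-3


def parse_compression_method(spec: str) -> Tuple[str, Optional[int]]:
    """Split a shorthand such as "igzip3" into ("igzip", 3); no suffix means no explicit level."""
    match = re.fullmatch(r"([a-z-]+?)(\d*)", spec)
    if match is None or match.group(1) not in KNOWN_METHODS:
        raise ValueError(f"unknown compression method: {spec!r}")
    method, digits = match.groups()
    if not digits:
        return method, None
    level = int(digits)
    if method not in LEVEL_RANGES or level not in LEVEL_RANGES[method]:
        raise ValueError(f"invalid compression level {level} for {method!r}")
    return method, level


# Stdlib gzip on a mostly-zero, frame-sized byte string at a few levels
frame = bytes(6000) + bytes(range(256)) * 4
for level in (0, 1, 9):
    packed = gzip.compress(frame, compresslevel=level)
    assert gzip.decompress(packed) == frame  # lossless at every level
    print(f"gzip level {level}: {len(frame)} B -> {len(packed)} B")
```

With this convention, "igzip3" parses to ("igzip", 3), "rle-jit" carries no level, and out-of-range requests such as "igzip7" raise a ValueError rather than being silently clamped.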

Installation

Install via PyPI:

pip install "sb3-extra-buffers[fast,extra]"

Other install options:

pip install "sb3-extra-buffers"          # only installs minimum requirements
pip install "sb3-extra-buffers[extra]"   # installs extra dependencies for SB3
pip install "sb3-extra-buffers[fast]"    # installs python-isal and numba
pip install "sb3-extra-buffers[isal]"    # only installs python-isal
pip install "sb3-extra-buffers[numba]"   # only installs numba
pip install "sb3-extra-buffers[vizdoom]" # installs vizdoom
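Since igzip depends on python-isal and rle-jit depends on numba, it can help to confirm which optional backends are actually importable after installing. The following standard-library check is a sketch, not a helper shipped with sb3-extra-buffers; it assumes the import names isal, numba and vizdoom for the corresponding packages.

```python
import importlib.util

# Assumed import names: python-isal -> "isal", numba -> "numba", ViZDoom -> "vizdoom"
OPTIONAL_MODULES = {
    "igzip (python-isal)": "isal",
    "rle-jit (numba)": "numba",
    "ViZDoom": "vizdoom",
}


def check_optional_modules() -> dict:
    """Map each optional feature to whether its top-level module can be found (without importing it)."""
    return {feature: importlib.util.find_spec(module) is not None
            for feature, module in OPTIONAL_MODULES.items()}


for feature, installed in check_optional_modules().items():
    print(f"{feature}: {'installed' if installed else 'missing'}")
```

importlib.util.find_spec only locates the module, so numba's comparatively slow import is not triggered by this check.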

Example Usage

from stable_baselines3 import PPO
from stable_baselines3.common.utils import get_linear_fn
from stable_baselines3.common.callbacks import EvalCallback
from sb3_extra_buffers.compressed import CompressedRolloutBuffer, find_buffer_dtypes
from sb3_extra_buffers.training_utils.atari import make_env

ATARI_GAME = "MsPacmanNoFrameskip-v4"

if __name__ == "__main__":
    # Get the most suitable dtypes for CompressedRolloutBuffer to use
    obs = make_env(env_id=ATARI_GAME, n_envs=1, framestack=4).observation_space
    compression = "rle-jit"  # or use "igzip1" since it's relatively noisy
    buffer_dtypes = find_buffer_dtypes(obs_shape=obs.shape, elem_dtype=obs.dtype, compression_method=compression)

    # Create vectorized environments after the find_buffer_dtypes call, which initializes jit
    env = make_env(env_id=ATARI_GAME, n_envs=8, framestack=4)
    eval_env = make_env(env_id=ATARI_GAME, n_envs=10, framestack=4)

    # Create PPO model with CompressedRolloutBuffer as rollout buffer class
    # device="mps" targets Apple Silicon; use "cuda", "cpu" or "auto" on other hardware
    model = PPO("CnnPolicy", env, verbose=1, learning_rate=get_linear_fn(2.5e-4, 0, 1), n_steps=128,
                batch_size=256, clip_range=get_linear_fn(0.1, 0, 1), n_epochs=4, ent_coef=0.01, vf_coef=0.5,
                seed=1970626835, device="mps", rollout_buffer_class=CompressedRolloutBuffer,
                rollout_buffer_kwargs=dict(dtypes=buffer_dtypes, compression_method=compression))

    # Evaluation callback (optional)
    eval_callback = EvalCallback(eval_env, n_eval_episodes=20, eval_freq=8192, log_path=f"./logs/{ATARI_GAME}/ppo/eval",
                                 best_model_save_path=f"./logs/{ATARI_GAME}/ppo/best_model")

    # Training
    model.learn(total_timesteps=10_000_000, callback=eval_callback, progress_bar=True)

    # Save the final model
    model.save("ppo_MsPacman_4.zip")

    # Cleanup
    env.close()
    eval_env.close()

Current Project Structure

sb3_extra_buffers
    |- compressed
    |    |- CompressedRolloutBuffer: RolloutBuffer with compression
    |    |- CompressedReplayBuffer: ReplayBuffer with compression
    |    |- CompressedArray: Compressed numpy.ndarray subclass
    |    |- find_buffer_dtypes: Find suitable buffer dtypes and initialize jit
    |
    |- recording
    |    |- RecordBuffer: A buffer for recording game states
    |    |- FramelessRecordBuffer: RecordBuffer but not recording game frames
    |    |- DummyRecordBuffer: Dummy RecordBuffer, records nothing
    |
    |- training_utils
         |- eval_model: Evaluate models in vectorized environment
         |- warmup: Perform buffer warmup for off-policy algorithms

Example Scripts

Example scripts are included and have been tested to ensure they work properly.

Evaluation results for example training scripts:

PPO on PongNoFrameskip-v4, trained for 10M steps using rle-jit, framestack: None

(Best ) Evaluated 10000 episodes, mean reward: 21.0 +/- 0.00
Q1:   21 | Q2:   21 | Q3:   21 | Relative IQR: 0.00 | Min: 21 | Max: 21
(Final) Evaluated 10000 episodes, mean reward: 21.0 +/- 0.02
Q1:   21 | Q2:   21 | Q3:   21 | Relative IQR: 0.00 | Min: 20 | Max: 21

PPO on MsPacmanNoFrameskip-v4, trained for 10M steps using rle-jit, framestack: 4

(Best ) Evaluated 10000 episodes, mean reward: 2667.0 +/- 290.00
Q1: 2300 | Q2: 2490 | Q3: 3000 | Relative IQR: 0.28 | Min: 2300 | Max: 3000
(Final) Evaluated 10000 episodes, mean reward: 2500.9 +/- 221.03
Q1: 2300 | Q2: 2390 | Q3: 2490 | Relative IQR: 0.08 | Min: 1420 | Max: 3000

DQN on MsPacmanNoFrameskip-v4, trained for 10M steps using rle-jit, framestack: 4

(Best ) Evaluated 10000 episodes, mean reward: 3300.0 +/- 770.79
Q1: 2490 | Q2: 4020 | Q3: 4020 | Relative IQR: 0.38 | Min: 2460 | Max: 4020
(Final) Evaluated 10000 episodes, mean reward: 3379.2 +/- 453.78
Q1: 2690 | Q2: 3400 | Q3: 3880 | Relative IQR: 0.35 | Min: 1230 | Max: 4090
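In the results above, the reported Relative IQR equals (Q3 - Q1) / Q2, e.g. (3000 - 2300) / 2490 ≈ 0.28 for the best PPO MsPacman model. A summary in the same format can be sketched with NumPy as follows; summarize_rewards is a hypothetical name, and the choice of np.percentile's default linear interpolation and population standard deviation are assumptions that may differ from what the library's eval_model does.

```python
import numpy as np


def summarize_rewards(rewards) -> str:
    """Format episode rewards like the evaluation summaries above (hypothetical helper)."""
    r = np.asarray(rewards, dtype=np.float64)
    q1, q2, q3 = np.percentile(r, [25, 50, 75])
    # Interquartile range relative to the median; undefined when the median is 0
    rel_iqr = (q3 - q1) / q2 if q2 != 0 else float("nan")
    return (f"Evaluated {r.size} episodes, mean reward: {r.mean():.1f} +/- {r.std():.2f}\n"
            f"Q1: {q1:4.0f} | Q2: {q2:4.0f} | Q3: {q3:4.0f} | Relative IQR: {rel_iqr:.2f} | "
            f"Min: {r.min():.0f} | Max: {r.max():.0f}")


print(summarize_rewards([2300, 2300, 2490, 3000, 3000]))
```

Relative IQR is a scale-free spread measure, which is why the Pong runs (every episode scoring 21) report 0.00 while the MsPacman runs show noticeably wider spreads.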

Pytest

Make sure pytest (and optionally pytest-xdist) is installed. The tests are compatible with pytest-xdist because DummyVecEnv is used in all of them.

# pytest
pytest tests -v --durations=0 --tb=short
# pytest-xdist
pytest tests -n auto -v --durations=0 --tb=short

Compressed Buffers

Defined in sb3_extra_buffers.compressed

JIT Before Multi-Processing: When using rle-jit, remember to trigger JIT compilation (via find_buffer_dtypes or init_jit) before any multi-processing code is executed.

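# Imports for this snippet (other setup code omitted)
from stable_baselines3.common.vec_env import SubprocVecEnv
from sb3_extra_buffers.compressed import find_buffer_dtypes
from sb3_extra_buffers.training_utils.atari import make_env

ATARI_GAME = "MsPacmanNoFrameskip-v4"

if __name__ == "__main__":
    # Get observation space from a single environment
    obs = make_env(env_id=ATARI_GAME, n_envs=1, framestack=4).observation_space

    # Get the buffer datatype settings via find_buffer_dtypes (this also triggers JIT compilation for rle-jit)
    compression = "rle-jit"
    buffer_dtypes = find_buffer_dtypes(obs_shape=obs.shape, elem_dtype=obs.dtype, compression_method=compression)

    # Now, safe to initialize multi-processing environments!
    env = SubprocVecEnv(...)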

Memory Usage Test for Compressed Buffers (on MsPacmanNoFrameskip-v4)

  • Frame Stack & Vec Envs: both 4
  • Buffer Size: 400,000 (split across 4 vectorized environments)
  • Using optimize_memory_usage: True
  • Steps Per Env: 121,558 (Total Observations = 486,232, but truncated to 400,000)
  • Evaluation Time: 00:17:33 on an M4 MacBook Air using the mps backend (more than 10 minutes of which were spent on the gzip buffers alone).
  • Settings: The example DQN model is loaded and evaluated using the code in examples/example_eval_memory_saving.py. The exact same observations are stored in each buffer. Latency is the total time in seconds spent adding observations to a given buffer, and baseline refers to using ReplayBuffer directly.
Compression  Memory (MB)  Memory %  Latency (s)
baseline           10767    100.0%        3.708
rle-jit             2048     19.0%       14.621
igzip0               576      5.3%       12.372
igzip1               490      4.5%       21.890
igzip2               489      4.5%       20.515
gzip1                480      4.5%       41.989
gzip3                439      4.1%       46.568
igzip3               432      4.0%       35.700
gzip5                386      3.6%       64.545
gzip7                372      3.5%      117.228
gzip9                369      3.4%      354.114
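The baseline figure can be sanity-checked with simple arithmetic. The sketch below assumes standard Atari preprocessing (84x84 grayscale uint8 frames), a frame stack of 4, and that "MB" in the table means MiB (2**20 bytes); with optimize_memory_usage=True, SB3's ReplayBuffer does not allocate a separate next_observations array, so only one copy of the observations is counted.

```python
import math

# Assumptions: 84x84 grayscale uint8 frames, framestack 4, MB in the table meaning MiB
BUFFER_SIZE = 400_000
OBS_SHAPE = (4, 84, 84)
BYTES_PER_ELEM = 1  # uint8

# optimize_memory_usage=True: next observations are read from the following slot,
# so only one observation array is stored
obs_mib = BUFFER_SIZE * math.prod(OBS_SHAPE) * BYTES_PER_ELEM / 2**20
print(f"raw observation storage: {obs_mib:.0f} MiB")

# Reduction factors relative to the measured baseline, using figures from the table
reported_mb = {"rle-jit": 2048, "igzip0": 576, "igzip3": 432, "gzip9": 369}
for method, mb in reported_mb.items():
    print(f"{method}: {10767 / mb:.1f}x smaller than baseline")
```

The observation array alone works out to about 10767 MiB, matching the baseline row, which suggests the baseline is dominated by observation storage; actions, rewards and done flags add comparatively little.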

Recording Buffers

Defined in sb3_extra_buffers.recording. Mainly used together with SegDoom to record game states.

WIP


Training Utils

Defined in sb3_extra_buffers.training_utils. Provides buffer warm-up and model evaluation utilities.

WIP


Download files


Source Distribution

sb3_extra_buffers-0.3.2.tar.gz (26.3 kB)

Uploaded Source

Built Distribution


sb3_extra_buffers-0.3.2-py3-none-any.whl (35.6 kB)

Uploaded Python 3

File details

Details for the file sb3_extra_buffers-0.3.2.tar.gz.

File metadata

  • Download URL: sb3_extra_buffers-0.3.2.tar.gz
  • Size: 26.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.2

File hashes

Hashes for sb3_extra_buffers-0.3.2.tar.gz
Algorithm Hash digest
SHA256 bc67bac0d3bc84099d4795f30d2fdb1d714154cc6262ad114c28ce17de0259d3
MD5 ef46ebf77243808836ea7419f7cd4434
BLAKE2b-256 ee7bf6a74effa0b377df1556412d09e5d4264dd48790c5c9f5a98bcc810a0a62


File details

Details for the file sb3_extra_buffers-0.3.2-py3-none-any.whl.

File hashes

Hashes for sb3_extra_buffers-0.3.2-py3-none-any.whl
Algorithm Hash digest
SHA256 05a5798603d9754124715fb35fd7400cb63aedc6b2aa79652dd31ee9f2ee799c
MD5 8d90fd8ebbfdf1c9f48cdc41405b579e
BLAKE2b-256 2b35e38c1df50368c563f760df5e1d22bfd7e559243fffcd5d6d60c189ccf39a

