Skip to main content

Gradient Boosted Trees for RL

Reason this release was yanked:

buggy

Project description

Gradient Boosting Reinforcement Learning (GBRL)

GBRL is a Python-based Gradient Boosting Trees (GBT) library, similar to popular packages such as XGBoost, CatBoost, but specifically designed and optimized for reinforcement learning (RL). GBRL is implemented in C++/CUDA aimed to seamlessly integrate within popular RL libraries.

License PyPI version

Overview

GBRL adapts the power of Gradient Boosting Trees to the unique challenges of RL environments, including non-stationarity and the absence of predefined targets. The following diagram illustrates how GBRL uses gradient boosting trees in RL:

GBRL Diagram

GBRL features a shared tree-based structure for policy and value functions, significantly reducing memory and computational overhead, enabling it to tackle complex, high-dimensional RL problems.

Key Features:

  • GBT Tailored for RL: GBRL adapts the power of Gradient Boosting Trees to the unique challenges of RL environments, including non-stationarity and the absence of predefined targets.
  • Optimized Actor-Critic Architecture: GBRL features a shared tree-based structure for policy and value functions. This significantly reduces memory and computational overhead, enabling it to tackle complex, high-dimensional RL problems.
  • Hardware Acceleration: GBRL leverages CUDA for hardware-accelerated computation, ensuring efficiency and speed.
  • Seamless Integration: GBRL is designed for easy integration with popular RL libraries. We implemented GBT-based actor-critic algorithm implementations (A2C, PPO, and AWR) in stable_baselines3 GBRL_SB3.

Performance

The following results, obtained using the GBRL_SB3 repository, demonstrate the performance of PPO with GBRL compared to neural-networks across various scenarios and environments:

PPO GBRL results in stable_baselines3

Getting started

Prerequisites

  • Python 3.9 or higher
  • LLVM and OpenMP (macOS).

Installation

To install GBRL via pip, use the following command:

pip install gbrl

For further installation details and dependencies see the documentation.

Usage Example

For a detailed usage example, see tutorial.ipynb

Current Supported Features

Tree Fitting

  • Greedy (Depth-wise) tree building - (CPU/GPU)
  • Oblivious (Symmetric) tree building - (CPU/GPU)
  • L2 split score - (CPU/GPU)
  • Cosine split score - (CPU/GPU)
  • Uniform based candidate generation - (CPU/GPU)
  • Quantile based candidate generation - (CPU/GPU)
  • Supervised learning fitting / Multi-iteration fitting - (CPU/GPU)
    • MultiRMSE loss (only)
  • Categorical inputs
  • Input feature weights - (CPU/GPU)

GBT Inference

  • SGD optimizer - (CPU/GPU)
  • ADAM optimizer - (CPU only)
  • Control Variates (gradient variance reduction technique) - (CPU only)
  • Shared Tree for policy and value function - (CPU/GPU)
  • Linear and constant learning rate scheduler - (CPU/GPU only constant)
  • Support for up to two different optimizers (e.g, policy/value) - **(CPU/GPU if both are SGD)
  • SHAP value calculation

Documentation

For comprehensive documentation, visit the GBRL documentation.

Citation

@article{gbrl,
  title={Gradient Boosting Reinforcement Learning},
  author={Benjamin Fuhrer, Chen Tessler, Gal Dalal},
  year={2024},
  eprint={2407.08250},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2407.08250}, 
}

Licenses

Copyright © 2024, NVIDIA Corporation. All rights reserved.

This work is made available under the NVIDIA Source Code License-NC. Click here. to view a copy of this license.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

gbrl-1.0.9.tar.gz (107.9 kB view details)

Uploaded Source

File details

Details for the file gbrl-1.0.9.tar.gz.

File metadata

  • Download URL: gbrl-1.0.9.tar.gz
  • Upload date:
  • Size: 107.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.6

File hashes

Hashes for gbrl-1.0.9.tar.gz
Algorithm Hash digest
SHA256 69102f42a2ab69d2060df96e459ea81ec54ae82a4e9d296a397c5d5b7c6341a6
MD5 d13ee171b1443cfe3ff53cea21627971
BLAKE2b-256 f077e0db34b6bd0a2d057f25424bebfbce6506d865f00fdf0eec52e0c5e58220

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page