Gradient Boosted Trees for RL
Gradient Boosting Reinforcement Learning (GBRL)
GBRL is a Python-based Gradient Boosting Trees (GBT) library, similar to popular packages such as XGBoost and CatBoost, but specifically designed and optimized for reinforcement learning (RL). GBRL is implemented in C++/CUDA and designed to integrate seamlessly with popular RL libraries.
Overview
GBRL adapts the power of Gradient Boosting Trees to the unique challenges of RL environments, including non-stationarity and the absence of predefined targets. The following diagram illustrates how GBRL uses gradient boosting trees in RL:
GBRL features a shared tree-based structure for the policy and value functions, significantly reducing memory and computational overhead and enabling it to tackle complex, high-dimensional RL problems.
Key Features:
- GBT Tailored for RL: GBRL adapts the power of Gradient Boosting Trees to the unique challenges of RL environments, including non-stationarity and the absence of predefined targets.
- Optimized Actor-Critic Architecture: GBRL features a shared tree-based structure for policy and value functions. This significantly reduces memory and computational overhead, enabling it to tackle complex, high-dimensional RL problems.
- Hardware Acceleration: GBRL leverages CUDA for hardware-accelerated computation, ensuring efficiency and speed.
- Seamless Integration: GBRL is designed for easy integration with popular RL libraries. GBT-based implementations of actor-critic algorithms (A2C, PPO, and AWR) are available for stable_baselines3 in the GBRL_SB3 repository.
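To make the first feature concrete, here is a minimal, self-contained sketch of the underlying idea: in gradient boosting, each new tree is fit to the (negative) gradient of the objective with respect to the current predictions, so there is no need for fixed regression targets. This is a conceptual illustration in pure NumPy, not GBRL's actual implementation; the stump fitter, learning rate, and toy objective are all illustrative choices.

```python
import numpy as np

def fit_stump(x, g):
    # Fit a depth-1 regression tree (stump) to targets g by scanning
    # all split thresholds and minimizing squared error.
    order = np.argsort(x)
    xs, gs = x[order], g[order]
    best = (np.inf, xs[0], gs.mean(), gs.mean())
    for i in range(1, len(xs)):
        left, right = gs[:i].mean(), gs[i:].mean()
        err = ((gs[:i] - left) ** 2).sum() + ((gs[i:] - right) ** 2).sum()
        if err < best[0]:
            best = (err, (xs[i - 1] + xs[i]) / 2, left, right)
    _, thr, lv, rv = best
    return lambda q: np.where(q <= thr, lv, rv)

rng = np.random.default_rng(0)
x = rng.uniform(-2, 2, 256)
target = np.sin(x)                 # stand-in for per-sample objectives
pred = np.zeros_like(x)
trees, lr = [], 0.1
for _ in range(200):
    grad = pred - target           # gradient of 0.5 * (pred - target)^2
    tree = fit_stump(x, -grad)     # each new tree approximates -gradient
    trees.append(tree)
    pred += lr * tree(x)           # functional gradient-descent update

print(float(np.mean((pred - target) ** 2)))  # shrinks as trees accumulate
```

In RL the gradient fed to each tree comes from the policy objective (which changes as the policy improves) rather than from a fixed supervised target, which is what makes the gradient-based formulation a natural fit.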
Performance
The following results, obtained using the GBRL_SB3 repository, demonstrate the performance of PPO with GBRL compared to neural networks across various scenarios and environments:
Getting started
Prerequisites
- Python 3.9 or higher
- LLVM and OpenMP (macOS).
Installation
To install GBRL via pip, run:

```shell
pip install gbrl
```
For further installation details and dependencies see the documentation.
Usage Example
For a detailed usage example, see tutorial.ipynb.
Currently Supported Features
Tree Fitting
- Greedy (Depth-wise) tree building - (CPU/GPU)
- Oblivious (Symmetric) tree building - (CPU/GPU)
- L2 split score - (CPU/GPU)
- Cosine split score - (CPU/GPU)
- Uniform-based candidate generation - (CPU/GPU)
- Quantile-based candidate generation - (CPU/GPU)
- Supervised learning fitting / Multi-iteration fitting - (CPU/GPU)
- MultiRMSE loss (the only supported loss)
- Categorical inputs
- Input feature weights - (CPU/GPU)
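The two candidate-generation strategies in the list above differ in where they place split thresholds. The following sketch (a NumPy illustration of the general technique, not GBRL's code) contrasts them: quantile-based candidates follow the data distribution, while uniform-based candidates are evenly spaced over the feature's range.

```python
import numpy as np

def quantile_candidates(values, n_candidates=8):
    # Thresholds at evenly spaced quantiles, so candidate density
    # tracks where the data is concentrated.
    qs = np.linspace(0, 1, n_candidates + 2)[1:-1]  # drop 0% / 100%
    return np.quantile(values, qs)

def uniform_candidates(values, n_candidates=8):
    # Thresholds evenly spaced over the feature's observed range.
    lo, hi = values.min(), values.max()
    return np.linspace(lo, hi, n_candidates + 2)[1:-1]

rng = np.random.default_rng(1)
skewed = rng.exponential(1.0, 10_000)   # heavy-tailed feature
q = quantile_candidates(skewed)
u = uniform_candidates(skewed)
# Quantile candidates cluster near zero where the data is dense;
# uniform candidates also cover the sparse tail.
print(q.max() < u.max())
```

For heavy-tailed or skewed features, quantile candidates typically give each threshold a comparable share of the samples, which is why both options are worth having.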
GBT Inference
- SGD optimizer - (CPU/GPU)
- ADAM optimizer - (CPU only)
- Control Variates (gradient variance reduction technique) - (CPU only)
- Shared Tree for policy and value function - (CPU/GPU)
- Linear and constant learning rate schedulers - (constant: CPU/GPU; linear: CPU only)
- Support for up to two different optimizers (e.g., policy/value) - (CPU/GPU if both are SGD)
- SHAP value calculation
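The "shared tree for policy and value function" item can be pictured as a single tree structure whose leaves store a vector payload: policy outputs plus a value estimate, so both functions reuse the same splits. The class below is a toy, hypothetical illustration of that layout (a single split on one feature), not GBRL's actual data structure.

```python
import numpy as np

class SharedLeafTree:
    """One split, two leaves; each leaf holds [policy logits..., value].

    Policy and value share the tree structure and differ only in which
    slice of the leaf payload they read.
    """

    def __init__(self, threshold, left_leaf, right_leaf):
        self.threshold = threshold                 # split on feature 0
        self.left = np.asarray(left_leaf, dtype=float)
        self.right = np.asarray(right_leaf, dtype=float)

    def predict(self, x):
        leaf = self.left if x[0] <= self.threshold else self.right
        logits, value = leaf[:-1], float(leaf[-1])
        return logits, value

# Two policy logits and one value per leaf (illustrative numbers).
tree = SharedLeafTree(0.5, left_leaf=[0.2, -0.1, 1.5],
                      right_leaf=[-0.3, 0.4, -0.7])
logits, value = tree.predict(np.array([0.3]))
```

Because one traversal yields both outputs, the actor and critic pay for the tree structure once, which is the source of the memory and compute savings described in the overview.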
Documentation
For comprehensive documentation, visit the GBRL documentation.
Citation
@article{gbrl,
  title={Gradient Boosting Reinforcement Learning},
  author={Benjamin Fuhrer and Chen Tessler and Gal Dalal},
  year={2024},
  eprint={2407.08250},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2407.08250},
}
Licenses
Copyright © 2024, NVIDIA Corporation. All rights reserved.
This work is made available under the NVIDIA Source Code License-NC. Click here to view a copy of this license.
Source Distribution
File details
Details for the file gbrl-1.0.6.tar.gz.
File metadata
- Download URL: gbrl-1.0.6.tar.gz
- Upload date:
- Size: 102.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 44dfe6b20a0ed4615d064dd44888ee15afbe6adf3bd2a60340eb234bb2b02b9c
MD5 | 0ebf0d4908993b5d1f2217518e6833ee
BLAKE2b-256 | 0adceb1a740659c06b860b02bfd0753599a28dc3824dc5fb788954f2105cc8c5