PyTorch implementations of generative reinforcement learning algorithms
Generative Reinforcement Learning (GRL)
English | 简体中文(Simplified Chinese)
GenerativeRL, short for Generative Reinforcement Learning, is a Python library for solving reinforcement learning (RL) problems using generative models, such as diffusion models and flow models. This library aims to provide a framework for combining the power of generative models with the decision-making capabilities of reinforcement learning algorithms.
Outline
- Features
- Framework Structure
- Integrated Generative Models
- Integrated Algorithms
- Installation
- Quick Start
- Documentation
- Tutorials
- Benchmark experiments
Features
- Support for training, evaluating, and deploying diverse generative models, including diffusion models and flow models
- Integration of generative models for state representation, action representation, policy learning, and dynamics model learning in RL
- Implementation of popular RL algorithms tailored for generative models, such as Q-guided policy optimization (QGPO)
- Support for various RL environments and benchmarks
- Easy-to-use API for training and evaluation
Framework Structure
Integrated Generative Models
|  | Score Matching | Flow Matching |
|---|---|---|
| **Diffusion Model** |  |  |
| Linear VP SDE | ✔ | ✔ |
| Generalized VP SDE | ✔ | ✔ |
| Linear SDE | ✔ | ✔ |
| **Flow Model** |  |  |
| Independent Conditional Flow Matching | 🚫 | ✔ |
| Optimal Transport Conditional Flow Matching | 🚫 | ✔ |
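For reference, the "Flow Matching" column corresponds to training objectives of roughly the following form. This is a minimal, self-contained PyTorch sketch of independent conditional flow matching, not GenerativeRL's internal API; the `VelocityNet` toy network and the linear interpolation path are illustrative assumptions.

```python
# Minimal sketch of independent conditional flow matching (I-CFM); illustrative only,
# not GenerativeRL's internal classes. `VelocityNet` is a hypothetical toy network.
import torch
import torch.nn as nn

class VelocityNet(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, 128), nn.SiLU(), nn.Linear(128, dim))

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Condition on time by concatenating t to the sample.
        return self.net(torch.cat([x, t], dim=-1))

def cfm_loss(model: nn.Module, x1: torch.Tensor) -> torch.Tensor:
    # x1: a batch of data samples; x0: samples from the Gaussian prior.
    x0 = torch.randn_like(x1)
    t = torch.rand(x1.shape[0], 1)
    # Linear interpolation path between prior and data; the target velocity is x1 - x0.
    xt = (1.0 - t) * x0 + t * x1
    target_velocity = x1 - x0
    return ((model(xt, t) - target_velocity) ** 2).mean()

model = VelocityNet(dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = cfm_loss(model, torch.randn(256, 2))  # toy data batch
loss.backward()
optimizer.step()
```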
Integrated Algorithms
| Algo./Models | Diffusion Model | Flow Model |
|---|---|---|
| QGPO | ✔ | 🚫 |
| SRPO | ✔ | 🚫 |
| GMPO | ✔ | ✔ |
| GMPG | ✔ | ✔ |
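The common thread in these algorithms is that a learned critic steers the generative policy. As a rough conceptual illustration only, QGPO-style Q-guided sampling can be loosely thought of as adding the gradient of a Q-function to the score during reverse sampling. The sketch below is a simplified stand-in (critic-gradient guidance with a plain Euler update), not GenerativeRL's or the QGPO paper's exact procedure; all networks, dimensions, and scales are hypothetical toys.

```python
# Conceptual sketch of Q-guided sampling in the spirit of QGPO; simplified stand-in,
# NOT GenerativeRL's implementation. All networks and scales are hypothetical toys.
import torch
import torch.nn as nn

obs_dim, act_dim = 8, 2  # e.g. LunarLanderContinuous-v2: 8-dim observations, 2-dim actions

score_net = nn.Sequential(nn.Linear(obs_dim + act_dim + 1, 64), nn.SiLU(), nn.Linear(64, act_dim))
q_net = nn.Sequential(nn.Linear(obs_dim + act_dim, 64), nn.SiLU(), nn.Linear(64, 1))

def guided_score(s, a, t, guidance_scale=1.0):
    # Bias the behavior-policy score toward actions with higher predicted Q-value.
    a = a.detach().requires_grad_(True)
    q = q_net(torch.cat([s, a], dim=-1)).sum()
    grad_q = torch.autograd.grad(q, a)[0]
    base_score = score_net(torch.cat([s, a, t], dim=-1))
    return base_score + guidance_scale * grad_q

def sample_action(s, num_steps=32, guidance_scale=1.0):
    # Plain Euler integration from noise to an action, for illustration only.
    a = torch.randn(s.shape[0], act_dim)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((s.shape[0], 1), 1.0 - i * dt)
        a = (a + guided_score(s, a, t, guidance_scale) * dt).detach()
    return a

actions = sample_action(torch.randn(4, obs_dim))
print(actions.shape)  # torch.Size([4, 2])
```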
Installation
```shell
pip install GenerativeRL
```
Or, if you want to install from source:
```shell
git clone https://github.com/opendilab/GenerativeRL.git
cd GenerativeRL
pip install -e .
```
Or you can use the docker image:
```shell
docker pull opendilab/grl:torch2.3.0-cuda12.1-cudnn8-runtime
docker run -it --rm --gpus all opendilab/grl:torch2.3.0-cuda12.1-cudnn8-runtime /bin/bash
```
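After any of these installation routes, a quick import check (a minimal sanity check, nothing library-specific) confirms that the package and PyTorch are visible:

```python
# Minimal sanity check that the installation succeeded.
import grl
import torch

print("grl imported from:", grl.__file__)
print("CUDA available:", torch.cuda.is_available())
```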
Quick Start
Here is an example of how to train a diffusion model for Q-guided policy optimization (QGPO) in the LunarLanderContinuous-v2 environment using GenerativeRL.
Install the required dependencies:
```shell
pip install 'gym[box2d]==0.23.1'
```
(Any gym version from 0.23 to 0.25 works for the Box2D environments, but 0.23.1 is recommended for compatibility with D4RL.)
Download the dataset from here and save it as `data.npz` in the current directory.
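Optionally, you can confirm that the downloaded file is readable with NumPy; the exact array names depend on the dataset, so this sketch only lists whatever is stored:

```python
# Inspect the downloaded dataset; the exact keys depend on the dataset file.
import numpy as np

data = np.load("./data.npz")
for key in data.files:
    print(key, data[key].shape)
```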
GenerativeRL uses WandB for logging, and it will ask you to log in to your account when first used. You can disable online logging by running:

```shell
wandb offline
```
```python
import gym

from grl.algorithms.qgpo import QGPOAlgorithm
from grl.datasets import QGPOCustomizedDataset
from grl.utils.log import log
from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import config


def qgpo_pipeline(config):
    # Train QGPO on the offline dataset, then deploy the resulting agent.
    qgpo = QGPOAlgorithm(
        config,
        dataset=QGPOCustomizedDataset(
            numpy_data_path="./data.npz",
            action_augment_num=config.train.parameter.action_augment_num,
        ),
    )
    qgpo.train()

    # Roll out the trained agent in the environment.
    agent = qgpo.deploy()
    env = gym.make(config.deploy.env.env_id)
    observation = env.reset()
    for _ in range(config.deploy.num_deploy_steps):
        env.render()
        observation, reward, done, _ = env.step(agent.act(observation))


if __name__ == "__main__":
    log.info("config: \n{}".format(config))
    qgpo_pipeline(config)
```
For more detailed examples and documentation, please refer to the GenerativeRL documentation.
Documentation
The full documentation for GenerativeRL can be found at GenerativeRL Documentation.
Tutorials
We provide several case tutorials to help you better understand GenerativeRL. See more at tutorials.
Benchmark experiments
We offer some baseline experiments to evaluate the performance of generative reinforcement learning algorithms. See more at benchmark.
Contributing
We welcome contributions to GenerativeRL! If you are interested in contributing, please refer to the Contributing Guide.
License
GenerativeRL is licensed under the Apache License 2.0. See LICENSE for more details.