
PyTorch implementations of generative reinforcement learning algorithms

Project description

Generative Reinforcement Learning (GRL)



GenerativeRL, short for Generative Reinforcement Learning, is a Python library for solving reinforcement learning (RL) problems using generative models, such as diffusion models and flow models. This library aims to provide a framework for combining the power of generative models with the decision-making capabilities of reinforcement learning algorithms.


Features

  • Support for training, evaluating, and deploying diverse generative models, including diffusion models and flow models
  • Integration of generative models for state representation, action representation, policy learning, and dynamics model learning in RL (see the sketch after this list)
  • Implementations of popular RL algorithms tailored for generative models, such as Q-guided policy optimization (QGPO)
  • Support for various RL environments and benchmarks
  • Easy-to-use API for training and evaluation
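
To make the policy-learning integration concrete: a policy can itself be a conditional generative model that turns noise into an action, conditioned on the current state. Below is a minimal PyTorch sketch of that pattern; the VelocityNet class and sample_action helper are illustrative stand-ins, not GenerativeRL's API:

import torch
import torch.nn as nn

# Illustrative sketch only -- not GenerativeRL's API. A policy is modeled as
# a conditional generative model: a velocity network transports Gaussian
# noise into an action, conditioned on the current state.
STATE_DIM, ACTION_DIM = 8, 2  # e.g. LunarLanderContinuous-v2 sizes

class VelocityNet(nn.Module):
    def __init__(self, state_dim=STATE_DIM, action_dim=ACTION_DIM, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim + 1, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, action_dim),
        )

    def forward(self, x, t, state):
        # x: intermediate action sample, t: time in [0, 1], state: observation
        return self.net(torch.cat([x, t, state], dim=-1))

@torch.no_grad()
def sample_action(model, state, steps=32):
    """Integrate the learned ODE from noise to an action with Euler steps."""
    x = torch.randn(state.shape[0], ACTION_DIM)
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((state.shape[0], 1), i * dt)
        x = x + model(x, t, state) * dt
    return x

policy = VelocityNet()
action = sample_action(policy, torch.randn(1, STATE_DIM))  # untrained demo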

Framework Structure

(Diagram: overview of the GenerativeRL framework structure.)

Integrated Generative Models

Model                                           Score Matching   Flow Matching
Diffusion Model (Open In Colab)                 ✔                ✔
  Linear VP SDE                                 ✔                ✔
  Generalized VP SDE                            ✔                ✔
  Linear SDE                                    ✔                ✔
Flow Model (Open In Colab)                      🚫               ✔
  Independent Conditional Flow Matching         🚫               ✔
  Optimal Transport Conditional Flow Matching   🚫               ✔
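
For reference, the flow-matching column trains a velocity field to match the straight-line path from noise to data. Here is a minimal, library-agnostic PyTorch sketch of one independent conditional flow matching (I-CFM) training step; the model argument is any network with the signature model(x_t, t, state), such as the VelocityNet sketched above:

import torch

def cfm_loss(model, x1, state):
    """One independent conditional flow matching (I-CFM) step: regress the
    velocity field onto the constant target (x1 - x0) along the line x0 -> x1."""
    x0 = torch.randn_like(x1)          # source noise sample
    t = torch.rand(x1.shape[0], 1)     # uniform time in [0, 1]
    x_t = (1 - t) * x0 + t * x1        # straight-line interpolant
    target = x1 - x0                   # conditional velocity along that line
    pred = model(x_t, t, state)
    return ((pred - target) ** 2).mean()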

Integrated Algorithms

Algo./Models           Diffusion Model   Flow Model
QGPO                   ✔                 🚫
SRPO                   ✔                 🚫
GMPO (Open In Colab)   ✔                 ✔
GMPG (Open In Colab)   ✔                 ✔
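
As an aside on QGPO: its core idea is to bias the generative sampler toward high-value actions by adding the gradient of a learned Q-function to the sampler's drift. The real implementation lives in grl.algorithms.qgpo; the following is only a hedged sketch of one guided update step, with illustrative names:

import torch

def q_guided_step(x, t, state, drift_model, q_model, dt, scale=1.0):
    """One illustrative guided sampler update: follow the learned drift
    (score/velocity) plus d Q(state, action) / d action, scaled by `scale`."""
    x = x.detach().requires_grad_(True)
    q_value = q_model(state, x).sum()            # reduce to a scalar for autograd
    grad_q = torch.autograd.grad(q_value, x)[0]  # gradient of Q w.r.t. the action
    with torch.no_grad():
        x = x + (drift_model(x, t, state) + scale * grad_q) * dt
    return x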

Installation

pip install GenerativeRL

Or, if you want to install from source:

git clone https://github.com/opendilab/GenerativeRL.git
cd GenerativeRL
pip install -e .

Or you can use the Docker image:

docker pull opendilab/grl:torch2.3.0-cuda12.1-cudnn8-runtime
docker run -it --rm --gpus all opendilab/grl:torch2.3.0-cuda12.1-cudnn8-runtime /bin/bash
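
After installing by any of these methods, you can sanity-check the installation (the package imports as grl, as in the quick-start script below):

python -c "import grl; print('GenerativeRL imports OK')"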

Quick Start

Here is an example of how to train a diffusion-model policy with Q-guided policy optimization (QGPO) in the LunarLanderContinuous-v2 environment using GenerativeRL.

Install the required dependencies:

pip install 'gym[box2d]==0.23.1'

(Any gym version from 0.23 to 0.25 works for the Box2D environments, but 0.23.1 is recommended for compatibility with D4RL.)

Download the dataset from here and save it as data.npz in the current directory.

GenerativeRL uses WandB (Weights & Biases) for logging and will ask you to log in to your account on first use. You can disable online logging by running:

wandb offline
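
Equivalently, you can set WandB's standard environment variable before launching:

export WANDB_MODE=offline

Then run the following Python script, which trains a QGPO agent on the offline dataset and deploys it in the environment: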
import gym

from grl.algorithms.qgpo import QGPOAlgorithm
from grl.datasets import QGPOCustomizedDataset
from grl.utils.log import log
from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import config

def qgpo_pipeline(config):
    # Train a QGPO agent on the offline dataset.
    qgpo = QGPOAlgorithm(
        config,
        dataset=QGPOCustomizedDataset(
            numpy_data_path="./data.npz",
            action_augment_num=config.train.parameter.action_augment_num,
        ),
    )
    qgpo.train()

    # Deploy the trained agent and roll it out in the environment.
    agent = qgpo.deploy()
    env = gym.make(config.deploy.env.env_id)
    observation = env.reset()
    for _ in range(config.deploy.num_deploy_steps):
        env.render()
        observation, reward, done, _ = env.step(agent.act(observation))
        if done:
            observation = env.reset()  # start a new episode when one ends

if __name__ == '__main__':
    log.info("config: \n{}".format(config))
    qgpo_pipeline(config)

For more detailed examples and documentation, please refer to the GenerativeRL documentation.

Documentation

The full documentation for GenerativeRL can be found at GenerativeRL Documentation.

Tutorials

We provide several case-study tutorials to help you better understand GenerativeRL. See more at tutorials.

Benchmark experiments

We offer some baseline experiments to evaluate the performance of generative reinforcement learning algorithms. See more at benchmark.

Contributing

We welcome contributions to GenerativeRL! If you are interested in contributing, please refer to the Contributing Guide.

License

GenerativeRL is licensed under the Apache License 2.0. See LICENSE for more details.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

generativerl-0.0.1.tar.gz (141.1 kB)

Uploaded Source

Built Distribution

GenerativeRL-0.0.1-py3-none-any.whl (191.3 kB)

Uploaded Python 3

File details

Details for the file generativerl-0.0.1.tar.gz.

File metadata

  • Download URL: generativerl-0.0.1.tar.gz
  • Upload date:
  • Size: 141.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for generativerl-0.0.1.tar.gz
Algorithm Hash digest
SHA256 911d75a8cf7481b67f93b67b4e4f6a29b364d4edb85a93bbe5fbdf64300fab0b
MD5 40596aa7a7d93594669832cb4bfdb589
BLAKE2b-256 29e836697d81b4ce99449e0b4d3ddaa9fa5483f4f9f981e9a5d5d039e45b29aa

See more details on using hashes here.
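
If you download the archive manually, you can verify it against the SHA256 digest above with a few lines of standard-library Python (the file is assumed to be in the current directory):

import hashlib

# SHA256 digest published above for generativerl-0.0.1.tar.gz
EXPECTED = "911d75a8cf7481b67f93b67b4e4f6a29b364d4edb85a93bbe5fbdf64300fab0b"

with open("generativerl-0.0.1.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print("OK" if digest == EXPECTED else "MISMATCH: " + digest)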

Provenance

The following attestation bundles were made for generativerl-0.0.1.tar.gz:

Publisher: release.yml on opendilab/GenerativeRL

Attestations:

File details

Details for the file GenerativeRL-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: GenerativeRL-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 191.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for GenerativeRL-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 f18f3577219075ecfca4d72ab82808e3b405ae8c17a0441e17c2c4f3a38832a0
MD5 4b3a8ab1593e0ca665c3abb810183802
BLAKE2b-256 744ff96c36768406c1ad7f69970828ab0a75c21ec1834f7b56fe5f19e9715226

See more details on using hashes here.

Provenance

The following attestation bundles were made for GenerativeRL-0.0.1-py3-none-any.whl:

Publisher: release.yml on opendilab/GenerativeRL

Attestations:
