
Transformers Reinforcement and Imitation Learning Library


TRIL


TRIL is a modular library for Reinforcement Learning (RL) and Imitation Learning (IL) algorithm development with transformers. We build directly on top of the transformers, accelerate, and peft libraries by 🤗 Hugging Face, so TRIL supports open-source pretrained models, distributed computing, and parameter-efficient training. TRIL currently supports most decoder-only and encoder-decoder architectures available in transformers.

Algorithms:

Planned Algorithms:

Installation

To install TRIL, run:

pip install tril

For the run scripts and example usage scripts, please see the repository.

To set up a development environment, we use conda for environment management. To install TRIL from source, run the following from the repository root:

conda create -n tril python=3.10
conda activate tril
pip install -e .

Optionally, to use caption_metrics such as CIDEr-D and SPICE, install these additional dependencies:

# Spacy model install
python -m spacy download en_core_web_sm

# CoreNLP library install
cd src/tril/metrics/caption_metrics/spice && bash get_stanford_models.sh

Example Scripts

The examples directory contains scripts for running TRIL algorithms on IMDB positive sentiment generation using PyTorch Fully Sharded Data Parallel (FSDP) and on TL;DR summarization using DeepSpeed. Each script is named in the format <task>_<alg>.yaml. Run an experiment as follows:

./examples/<script>

Within each script, the command is:

accelerate launch --config_file <accelerate config> [accelerate args] main.py task=<task config> alg=<alg config> [hydra CLI config specification]

See the accelerate launch tutorial for how to launch jobs with accelerate. We provide example accelerate configs in the accelerate_cfgs directory. For more details on Hydra CLI and config usage, see this tutorial.
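If you prefer to launch runs from Python instead of a shell script, the launch command can be assembled as an argument list. This is a minimal sketch under assumptions: the helper name build_launch_cmd and the example values (accelerate_cfgs/fsdp.yaml, imdb, ppo, experiment_name=imdb_ppo_test) are hypothetical placeholders, not file or key names confirmed by TRIL, and the sketch uses the standard `accelerate launch --config_file` form of the Accelerate CLI.

```python
from typing import List, Sequence


def build_launch_cmd(
    accelerate_cfg: str,
    task: str,
    alg: str,
    overrides: Sequence[str] = (),
    accelerate_args: Sequence[str] = (),
) -> List[str]:
    """Assemble an `accelerate launch` command for main.py (hypothetical helper)."""
    return [
        "accelerate", "launch",
        "--config_file", accelerate_cfg,
        *accelerate_args,    # Accelerate options, e.g. ["--num_processes", "2"]
        "main.py",
        f"task={task}",      # Hydra override selecting a task config
        f"alg={alg}",        # Hydra override selecting an algorithm config
        *overrides,          # any further Hydra overrides, e.g. "experiment_name=test"
    ]


cmd = build_launch_cmd(
    "accelerate_cfgs/fsdp.yaml", "imdb", "ppo",
    overrides=["experiment_name=imdb_ppo_test"],
)
```

Passing the resulting list to `subprocess.run(cmd, check=True)` from the repository root would start the run; keeping the arguments as a list avoids shell-quoting problems with values such as `TL;DR`.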

Usage Example

Here is a minimal example of running PPO with TRIL:

import logging

import hydra
from accelerate import Accelerator
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf

from tril import tril_run
from tril.logging import Tracker
from tril.algorithms import PPO

@hydra.main(version_base=None, config_path="cfgs", config_name="config")  # Hydra decorator for config loading
@tril_run  # TRIL decorator for Hydra config processing
def run_ppo(cfg):
    # Initialize the accelerator for distributed computing
    accelerator = Accelerator()

    # Grab the experiment save directory from Hydra
    save_path = HydraConfig.get().runtime.output_dir

    # Instantiate the TRIL logger for WandB and CLI logging/saving
    tracker = Tracker(
        save_path,
        OmegaConf.to_container(cfg, resolve=True),
        cfg.project_name,
        cfg.experiment_name,
        cfg.entity_name,
        cfg.log_to_wandb,
        log_level=logging.INFO,
        is_main_process=accelerator.is_main_process,
    )

    # Instantiate the algorithm
    ppo = PPO(cfg, accelerator, tracker)

    # Start training the LLM
    ppo.learn()

if __name__ == '__main__':
    run_ppo()

TRIL also provides an AlgorithmRegistry for instantiating algorithms; see our main.py for how the scripts use it. The available algorithms correspond to the configs in cfgs/alg.

Current Task/Algorithm Support Matrix

Algorithm IMDB CommonGen TL;DR
PPO ✅ ✅ ✅
PPO++ ✅ ✅ ✅
AggreVaTeD ✅ ✅ ✅
LOLS ✅ ✅
D2LOLS ✅ ✅
BC ✅
GAIL ✅

Code Structure

The directory structure of the configs, run script, and TRIL components looks like this.

├── cfgs                    <- Hydra configs
│   ├── alg                 <- Algorithm configs (e.g. PPO)
│   ├── task                <- Task configs (e.g. TL;DR summarization)
│   ├── logging             <- Logging configs (e.g. WandB)
│   │
│   └── config.yaml         <- Main config for training
│
├── accelerate_cfgs         <- Accelerate configs
│
├── main.py                 <- TRIL main function
│
├── tril                    <- TRIL src
│   ├── algorithms          <- Algorithm implementations
│   ├── buffers             <- Data Buffer (e.g. OnlineBuffer, PromptBuffer)
│   ├── metrics             <- Evaluation Metrics
│   ├── policies            <- Language Model Policies (e.g. Actor, ActorCritic)
│   ├── rewards             <- Reward Functions
│   ├── tasks               <- Supported Tasks
│   ├── utils               <- Helper functions for TRIL
│   │
│   ├── agent.py            <- Agent contains all torch.nn Modules (i.e. Policy and Reward)
│   ├── base_algorithm.py   <- Algorithm abstract class
│   ├── base_metric.py      <- Metric abstract class
│   ├── base_reward.py      <- Reward abstract class
│   ├── base_task.py        <- Task abstract class
│   └── logging.py          <- TRIL Logger

Each directory's __init__.py contains a registry of all supported algorithms, metrics, rewards, and tasks. When extending TRIL, add your new component to the corresponding registry.
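The registry pattern described above can be sketched in a few lines of plain Python. This is an illustration under assumptions, not TRIL's actual registry API: the Registry class, its register/get/names methods, and the name "my_alg" are hypothetical, and the MyAlgorithm constructor signature simply mirrors the PPO(cfg, accelerator, tracker) call from the usage example.

```python
from typing import Callable, Dict, List, TypeVar

T = TypeVar("T", bound=type)


class Registry:
    """Minimal name -> class registry (hypothetical; not TRIL's actual API)."""

    def __init__(self, kind: str) -> None:
        self.kind = kind
        self._entries: Dict[str, type] = {}

    def register(self, name: str) -> Callable[[T], T]:
        """Class decorator that records `cls` under `name`."""
        def decorator(cls: T) -> T:
            if name in self._entries:
                # Fail loudly instead of silently shadowing an existing entry.
                raise KeyError(f"{self.kind} '{name}' is already registered")
            self._entries[name] = cls
            return cls
        return decorator

    def get(self, name: str) -> type:
        """Look up a registered class, listing valid names on failure."""
        try:
            return self._entries[name]
        except KeyError:
            raise KeyError(
                f"unknown {self.kind} '{name}'; available: {self.names()}"
            ) from None

    def names(self) -> List[str]:
        return sorted(self._entries)


# Hypothetical registry that would live in algorithms/__init__.py.
ALGORITHMS = Registry("algorithm")


@ALGORITHMS.register("my_alg")
class MyAlgorithm:
    def __init__(self, cfg, accelerator, tracker) -> None:
        self.cfg = cfg
        self.accelerator = accelerator
        self.tracker = tracker

    def learn(self) -> None:
        pass  # the training loop would go here
```

A string-keyed registry lets a launcher pick the class from a config value (for example, a hypothetical `cfg.alg.id`) via `ALGORITHMS.get(...)` instead of hard-coding imports for every algorithm.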

Logging

TRIL supports Weights & Biases (WandB) logging. Enter your WandB details, such as entity_name and project_name, in cfgs/logging/wandb.yaml. To disable WandB logging, set log_to_wandb=False.
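A flag like log_to_wandb typically decides whether a W&B run is created at all. The sketch below shows one way such a flag can be honored. The helper maybe_init_wandb is hypothetical and does not describe how TRIL's Tracker is actually implemented. It calls the public `wandb.init(project=..., entity=..., name=..., config=...)` function and imports wandb lazily, so runs with logging disabled do not need the package installed.

```python
from typing import Any, Optional


def maybe_init_wandb(
    log_to_wandb: bool,
    project_name: str,
    entity_name: Optional[str],
    experiment_name: str,
    config: Optional[dict] = None,
) -> Optional[Any]:
    """Start a W&B run only when logging is enabled (hypothetical helper)."""
    if not log_to_wandb:
        # Logging disabled: skip W&B entirely; CLI/file logging is unaffected.
        return None
    import wandb  # lazy import: only required when WandB logging is on

    return wandb.init(
        project=project_name,
        entity=entity_name,
        name=experiment_name,
        config=config,
    )
```

In a multi-process run you would typically also call this only on the main process, similar to how the usage example passes accelerator.is_main_process to Tracker.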

By default, training and evaluation information is saved in outputs/<experiment_name>/<datetime>. You can set experiment_name in cfgs/config.yaml or through the Hydra CLI, e.g. main.py experiment_name=<name>.
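The output layout above can be mirrored with a small helper, for example to locate a run's files from a separate analysis script. This is a sketch under assumptions: the exact <datetime> format TRIL uses is not stated here, so the %Y-%m-%d_%H-%M-%S pattern and the helper names run_output_dir and latest_run_dir are hypothetical. In Hydra itself, a layout like this is normally configured through hydra.run.dir together with the ${now:...} resolver.

```python
from datetime import datetime
from pathlib import Path
from typing import Optional

# Assumed timestamp pattern; TRIL's Hydra config may use a different one.
TIMESTAMP_FMT = "%Y-%m-%d_%H-%M-%S"


def run_output_dir(
    experiment_name: str, root: str = "outputs", now: Optional[datetime] = None
) -> Path:
    """Build outputs/<experiment_name>/<datetime> for a given time (hypothetical)."""
    now = now or datetime.now()
    return Path(root) / experiment_name / now.strftime(TIMESTAMP_FMT)


def latest_run_dir(experiment_name: str, root: str = "outputs") -> Optional[Path]:
    """Return the most recent run directory for an experiment, if any exist."""
    base = Path(root) / experiment_name
    if not base.is_dir():
        return None
    runs = [p for p in base.iterdir() if p.is_dir()]
    # Zero-padded, year-first timestamps sort lexicographically in time order.
    return max(runs, key=lambda p: p.name) if runs else None
```

Because the assumed timestamp is zero-padded and ordered from year down to second, comparing directory names as strings is enough to find the newest run without parsing dates.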

Example WandB Reports

Here is an example WandB Report showing what the logging looks like when running several different algorithms.

Citing TRIL

If you use TRIL in your publication, please cite it by using the following BibTeX entry.

@Misc{TRIL,
  title =        {TRIL: Transformers Reinforcement and Imitation Learning Library},
  author =       {Kiante Brantley and Jonathan Chang and Wen Sun},
  howpublished = {\url{https://github.com/Cornell-RL/tril}},
  year =         {2023}
}

