

TRIL

Transformers Reinforcement and Imitation Learning Library

TRIL is a modular library for Reinforcement Learning (RL) and Imitation Learning (IL) algorithm development with transformers. We build directly on top of the transformers, accelerate, and peft libraries by 🤗 Hugging Face. This allows TRIL to support open-source pretrained models, distributed computing, and parameter-efficient training. Note that we currently support most decoder and encoder-decoder architectures available in transformers.

Supported Algorithms: PPO, PPO++, AggreVaTeD, LOLS, D2LOLS, BC, GAIL

Supported Tasks: IMDB positive sentiment generation, CommonGen, TL;DR summarization

Planned Algorithms:

Planned Tasks:

Installation

To install TRIL, run:

pip install tril

For run scripts and usage examples, please see the repository.

To set up a development environment, we use conda to manage the Python version. To install TRIL from source, follow these steps:

conda create -n tril python=3.10
conda activate tril
pip install -e .

Optionally, to use caption metrics such as CIDEr-D and SPICE, install these additional dependencies:

# Spacy model install
python -m spacy download en_core_web_sm

# CoreNLP library install
cd src/tril/metrics/caption_metrics/spice && bash get_stanford_models.sh

Example Scripts

In the examples directory, there are example scripts that run TRIL algorithms on IMDB positive sentiment generation using PyTorch Fully Sharded Data Parallel (FSDP) and on TL;DR summarization using DeepSpeed. Each script is named in the format <task>_<alg>.yaml. Run each experiment like the following:

./examples/<task>/<script>

Within each script, the command has the form

accelerate launch --config_file <accelerate config> [accelerate args] main.py task=<task config> alg=<alg config> [hydra CLI config specification]

Please see the accelerate launch tutorial for how to launch jobs with accelerate. We provide examples of different accelerate configs in the accelerate_cfgs directory. For more details on Hydra CLI and config usage, please see this tutorial.
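For instance, a concrete launch command might look like the following. The config filename and the task/algorithm override values shown here are illustrative placeholders, not necessarily the exact names shipped in the repository:

```shell
# Example launch: an FSDP accelerate config, IMDB task, PPO algorithm.
# accelerate_cfgs/fsdp_config.yaml, imdb, and ppo are hypothetical names.
accelerate launch --config_file accelerate_cfgs/fsdp_config.yaml \
    main.py task=imdb alg=ppo experiment_name=imdb_ppo_run
```

Any trailing key=value pairs are Hydra overrides, so additional config fields can be changed from the command line in the same way.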

Usage Example

Here is a minimal example of running PPO with TRIL:

import logging

import hydra
from accelerate import Accelerator
from omegaconf import OmegaConf

from tril import tril_run
from tril.logging import Tracker
from tril.algorithms import PPO

@hydra.main(version_base=None, config_path="cfgs", config_name="config") # Hydra Decorator for Config
@tril_run # TRIL decorator for hydra config processing
def run_ppo(cfg):
    # Initialize accelerator for distributed computing
    accelerator = Accelerator()

    # Grab experiment save directory from Hydra
    save_path = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

    # Instantiate TRIL logger for WandB and CLI logging/saving
    tracker = Tracker(
        save_path,
        OmegaConf.to_container(cfg, resolve=True),
        cfg.project_name,
        cfg.experiment_name,
        cfg.entity_name,
        cfg.log_to_wandb,
        log_level=logging.INFO,
        is_main_process=accelerator.is_main_process,
    )

    # Instantiate Algorithm
    ppo = PPO(cfg, accelerator, tracker)

    # Start learn to train LLM
    ppo.learn()

if __name__ == '__main__':
    run_ppo()

TRIL also provides an AlgorithmRegistry to instantiate algorithms. Please see our main.py for how our scripts instantiate the algorithms. The list of available algorithms can be seen in the configs in cfgs/alg.
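The registry pattern can be sketched as a simple name-to-class mapping that is looked up with an identifier from the config. The names below (ALGORITHM_REGISTRY, build_algorithm, DummyPPO, the alg_id key) are illustrative stand-ins, not TRIL's actual API:

```python
# Minimal sketch of a name -> class registry, mirroring how an algorithm
# registry can map a config's algorithm id to the class to instantiate.

ALGORITHM_REGISTRY = {}

def register(name):
    """Decorator that records a class in the registry under `name`."""
    def wrap(cls):
        ALGORITHM_REGISTRY[name] = cls
        return cls
    return wrap

@register("ppo")
class DummyPPO:
    """Toy stand-in for an algorithm class."""
    def __init__(self, cfg):
        self.cfg = cfg

    def learn(self):
        return f"training with lr={self.cfg['lr']}"

def build_algorithm(cfg):
    """Look up the class named in the config and instantiate it."""
    return ALGORITHM_REGISTRY[cfg["alg_id"]](cfg)

algo = build_algorithm({"alg_id": "ppo", "lr": 1e-5})
print(algo.learn())
```

Instantiating by name keeps run scripts generic: switching algorithms only requires changing a config value, not editing code.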

Current Task/Algorithm Support Matrix

Algorithm   | IMDB | CommonGen | TL;DR
PPO         | ✅   | ✅        | ✅
PPO++       | ✅   | ✅        | ✅
AggreVaTeD  | ✅   | ✅        | ✅
LOLS        | ✅   | ✅        | ✅
D2LOLS      | ✅   | ✅        | ✅
BC          | ✅   | ✅        | ✅
GAIL        | ✅   |           |

Code Structure

The directory structure of the configs, run script, and TRIL components looks like this.

├── cfgs                    <- Hydra configs
│   ├── alg                 <- Algorithm configs (e.g. PPO)
│   ├── task                <- Task configs (e.g. TL;DR summarization)
│   ├── logging             <- Logging configs (e.g. WandB)
│   │
│   └── config.yaml         <- Main config for training
│
├── accelerate_cfgs         <- Accelerate configs
│
├── main.py                 <- TRIL main function
│
├── tril                    <- TRIL src
│   ├── algorithms          <- Algorithm implementations
│   ├── buffers             <- Data Buffer (e.g. OnlineBuffer, PromptBuffer)
│   ├── metrics             <- Evaluation Metrics
│   ├── policies            <- Language Model Policies (e.g. Actor, ActorCritic)
│   ├── rewards             <- Reward Functions
│   ├── tasks               <- Supported Tasks
│   ├── utils               <- Helper functions for TRIL
│   │
│   ├── agent.py            <- Agent contains all torch.nn Modules (i.e. Policy and Reward)
│   ├── base_algorithm.py   <- Algorithm abstract class
│   ├── base_metric.py      <- Metric abstract class
│   ├── base_reward.py      <- Reward abstract class
│   ├── base_task.py        <- Task abstract class
│   └── logging.py          <- TRIL Logger

In each directory's __init__.py, there is a registry to register all supported algorithms, metrics, rewards, and tasks. When extending TRIL, please add the respective addition to one of these registries.
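Registering an extension can be sketched as adding one entry to the relevant registry in the subpackage's __init__.py. The registry name (METRIC_REGISTRY) and classes below are hypothetical placeholders, assuming a dict-style registry like the one described above:

```python
# Hypothetical sketch of extending a registry in a package __init__.py.

METRIC_REGISTRY = {}  # stand-in for a registry with pre-existing entries

class MyCustomMetric:
    """A new metric to make available to configs by name."""
    def compute(self, text):
        # Toy metric: count whitespace-separated tokens.
        return len(text.split())

# Register the new class so configs can refer to it by name.
METRIC_REGISTRY["my_custom_metric"] = MyCustomMetric

metric = METRIC_REGISTRY["my_custom_metric"]()
print(metric.compute("hello registry world"))
```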

Logging

TRIL supports Weights & Biases (WandB) logging. Please enter your WandB details, such as entity_name and project_name, into cfgs/logging/wandb.yaml. If you would prefer not to log to WandB, set log_to_wandb=False.

By default, we save training and evaluation information in outputs/<experiment_name>/<datetime>. You can define experiment_name in cfgs/config.yaml or through the Hydra CLI, e.g. main.py experiment_name=<name>.

Example WandB Reports

Here is an example WandB report showing what the logging looks like when running several different algorithms.

Citing TRIL

If you use TRIL in your publication, please cite it by using the following BibTeX entry.

@misc{TRIL,
      title={TRIL: Transformers Reinforcement and Imitation Learning Library},
      author={Jonathan D Chang and Kiante Brantley and Rajkumar Ramamurthy and Dipendra Misra and Wen Sun},
      howpublished={\url{https://github.com/Cornell-RL/tril}},
      year={2023}
}

Here is the citation for the accompanying paper, which introduces many of the supported algorithms:

@misc{chang2023learning,
      title={Learning to Generate Better Than Your LLM}, 
      author={Jonathan D. Chang and Kiante Brantley and Rajkumar Ramamurthy and Dipendra Misra and Wen Sun},
      year={2023},
      eprint={2306.11816},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Acknowledgements

We would like to acknowledge RL4LMs, TRL, and TRLx for being inspirations for this library.
