A package for creating IL datasets

IL Datasets

Hi, welcome to the Imitation Learning (IL) Datasets. Something that always bothered me was how difficult it was to find good expert weights, how much work it took to create datasets for different state-of-the-art methods, and how every method had to be re-run because there were no common datasets. For these reasons, I created this repository to make it easier for researchers to create datasets using experts hosted on Hugging Face. IL-Datasets provides teacher weights for different environments, a multi-threaded solution for creating datasets faster, ready-made datasets for a set of environments (check the bottom of this document to see which environments are already released), and a benchmark of common imitation learning methods. We hope that these features lower the barrier to learning about and researching imitation learning.

This project is under development. If you are interested in helping, feel free to contact me.

Requirements

The project supports Python versions 3.8 to 3.11. All requirements for the imitation_datasets package are listed in requirements.txt and are installed together with IL-Datasets. To use the benchmark package, install both the imitation_datasets requirements and those listed in benchmark.txt. Development requirements are listed in dev.txt; we do not recommend using them outside development, since they pin an outdated gym version (v0.21.0) used to test the GymWrapper class.

Install

IL-Datasets does not install its PyTorch and Gym dependencies, so it does not force users onto a specific version. We test IL-Datasets with pytorch@latest, gymnasium@latest and gym@v0.21.0. If you run into a problem with a different version, please open an issue so we can take a look.

The package is available on PyPI:

# Stable version
pip install il-datasets

Alternatively, you can install it from source.

git clone https://github.com/NathanGavenski/IL-Datasets.git
cd IL-Datasets
pip install -e .

How does it work?

The project uses multithreading to accelerate dataset creation. It is built around a single Controller class, which requires two functions to work: (i) an enjoy function, for the agent to play and record an episode; and (ii) a collate function, for putting all episodes together.


The enjoy function will receive 3 parameters and return 1:

"""
Args:
   path (str): where the episode is going to be recorded
   experiment (Context): A class for logging information (instead of using `print`, keeping the console clear)
   expert (Policy): A model based on the [StableBaselines3](https://stable-baselines3.readthedocs.io/en/master/) `BaseAlgorithm`.

Returns:
   status (bool): Whether it was successful or not
"""

Note: to use the model you can call predict; the Policy class already wraps it in the same form StableBaselines3 uses.
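
For reference, here is a minimal sketch of a custom enjoy function. It assumes a gymnasium environment and the (path, experiment, expert) signature above; the environment name, the per-episode file layout, and the StableBaselines3-style predict call are illustrative assumptions, not requirements of the package.

import gymnasium as gym
import numpy as np

def my_enjoy(path, experiment, expert) -> bool:
    # Assumption: any gymnasium environment works here; CartPole is illustrative.
    env = gym.make("CartPole-v1")
    obs, _ = env.reset()
    states, actions, rewards = [], [], []
    done = False
    while not done:
        # Policy.predict follows the StableBaselines3 convention and returns
        # the action together with the model's internal state.
        action, _ = expert.predict(obs, deterministic=True)
        states.append(obs)
        actions.append(action)
        obs, reward, terminated, truncated, _ = env.step(action)
        rewards.append(reward)
        done = terminated or truncated
    # `experiment` can be used for logging instead of print; it is unused in this sketch.
    # Record the episode so the collate function can merge it later.
    np.savez(path, obs=np.array(states), actions=np.array(actions), rewards=np.array(rewards))
    return True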


The collate function will receive 2 parameters and return 1:

"""
Args:
   path (str): Where it should save the final dataset
   episodes  (list[str]): A list of paths for each file

Returns:
   status (bool): Whether it was successful or not
"""

Default functions

The imitation_datasets package also comes with a set of default functions, so you don't need to implement an enjoy and a collate function in every project. The resulting dataset will be an NpzFile with the following data:

"""
Data:
   obs (list[list[float]]): gym environment observation. Size [steps, observation space].
   actions (list[float]): agent action. Size [steps, action] (1 if single action, n if multiple actions).
   rewards (list[int]): reward from the action with the observations (e.g., r(obs, action)). Size [steps, ].
   episode_returns (list[float]): accumulated reward for each episode. Size [number of episodes, ].
   episode_starts (list[bool]): whether the episode started at the current observation. Size [steps, ].
"""

A small functional example of how to use the given functions:

# python <script> --game cartpole --threads 4 --episodes 1000 --mode all
from imitation_datasets.functions import baseline_enjoy, baseline_collate
from imitation_datasets.controller import Controller
from imitation_datasets.args import get_args

args = get_args()
controller = Controller(baseline_enjoy, baseline_collate, args.episodes, args.threads)
controller.start(args)

This script will use the pre-registered CartPole-v1 environment with the HuggingFace weights and create a teacher.npz dataset file at ./dataset/cartpole/teacher.npz.

Registered environments

IL-Datasets comes with some already registered weights from HuggingFace. To check which environments are already registered, check under the src.imitation_datasets.registers folder.

Registering new experts

If you would like to add new experts locally, you can call the Experts class. It uses the following structure:

"""
Args:
   identifier (str): Name for calling the expert (e.g., cartpole).
   Policy (Policy): a dataclass with:
      name (str): Gym environment name
      repo_id (str): HuggingFace repo identification
      filename (str): HuggingFace weights file name
      threshold (float): How much reward an episode should accumulate to be considered good
      algo (BaseAlgorithm): The class from StableBaselines3
"""

If you are not using StableBaselines3, you can create a Policy and simply not call the load() function (which downloads weights from HuggingFace). The expert must, however, expose a predict function that receives:

"""
Args:
   obs (Tensor): current environment state
   state (Tensor): Model's internal state
   deterministic (bool): whether the policy should act deterministically (no exploration)
"""

Datasets

IL-Datasets also comes with a default PyTorch dataset, called BaselineDataset. It follows the format produced by the baseline_collate function and also supports HuggingFace datasets created by the baseline_to_huggingface function. The dataset list for benchmarking is under development; to check all new versions, visit our collection on HuggingFace.

To use the BaselineDataset, you can pass a local file:

from imitation_datasets.dataset import BaselineDataset
BaselineDataset("./dataset/cartpole/teacher.npz")

Or a HuggingFace path:

from imitation_datasets.dataset import BaselineDataset
BaselineDataset("NathanGavenski/CartPole-v1", source="huggingface")

Finally, the dataset supports loading fewer episodes and splitting into training and evaluation sets.

from imitation_datasets.dataset import BaselineDataset
dataset_train = BaselineDataset("NathanGavenski/CartPole-v1", source="huggingface", n_episodes=100)
dataset_eval = BaselineDataset("NathanGavenski/CartPole-v1", source="huggingface", n_episodes=100, split="eval")
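
Since BaselineDataset is a regular PyTorch dataset, it can be batched with a standard DataLoader (the batch size below is arbitrary):

from torch.utils.data import DataLoader

loader = DataLoader(dataset_train, batch_size=32, shuffle=True)
for batch in loader:
    ...  # feed the batch to your imitation learning method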

Benchmark

Last but not least, IL-Datasets comes with its own benchmark. It runs IL methods on already published datasets to provide consistent results for researchers who also use our datasets. Currently, we support:

  • Behavioural Cloning: benchmark.methods.bc
  • Behavioural Cloning from Observation: benchmark.methods.bco
  • Augmented Behavioural Cloning from Observation: benchmark.methods.abco
  • Imitating Unknown Policies via Exploration: benchmark.methods.iupe

However, our plan is to implement more state-of-the-art methods.

You can check the current benchmark results at benchmark_results.md.


This repository is under development

Here is a list of the upcoming releases:

  • Benchmark methods
    • Behavioural Cloning
    • Behavioural Cloning from Observation
    • Imitating Latent Policies from Observation
    • Augmented Behavioural Cloning from Observation
    • Imitating Unknown Policies via Exploration
    • Generative Adversarial Imitation Learning
    • Generative Adversarial Imitation Learning from Observation
    • Off-Policy Imitation Learning from Observations
    • Model-Based Imitation Learning From Observation Alone
    • Self-Supervised Adversarial Imitation Learning
  • Benchmark environments
    • CartPole-v1
    • MountainCar-v0
    • Acrobot-v1
    • LunarLander-v2
    • Ant-v3
    • Hopper-v3
    • HalfCheetah-v3
    • Walker2d-v3
    • Humanoid-v3
    • Swimmer-v3

Although there are many environments and methods still to cover, this repository will be considered complete once the documentation and the benchmark installation are finished. We do not yet have a release plan for the remaining environments and methods.

Cite

If you used IL-Datasets in your research and would like to cite us:

@inproceedings{gavenski2024ildatasets,
  title={Imitation Learning Datasets: A Toolkit For Creating Datasets, Training Agents and Benchmarking},
  author={Gavenski, Nathan and Luck, Michael and Rodrigues, Odinaldo},
  booktitle={Proceedings of the 2024 International Conference on Autonomous Agents and Multiagent Systems},
  year={2024}
}

If you like this repository, be sure to check my other projects:

Development-based

Academic

