Robustness Gym is an evaluation toolkit for machine learning.
Robustness Gym
Robustness Gym is a Python evaluation toolkit for machine learning models.
Getting started
pip install robustnessgym
Note: some parts of Robustness Gym rely on optional dependencies. If you know which optional dependencies you'd like to install, you can do so using something like
pip install robustnessgym[dev,text]
instead. See setup.py for a full list of optional dependencies.
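If you are unsure which optional dependencies are already present in an environment, the standard library's `importlib.util.find_spec` can report whether a top-level module is importable without actually importing it. The sketch below is a small convenience check, not part of Robustness Gym; the helper name `missing_modules` is hypothetical and the module names are placeholders (the real packages behind each extra are listed in setup.py).

```python
import importlib.util
from typing import Iterable, List


def missing_modules(names: Iterable[str]) -> List[str]:
    # find_spec returns None when a top-level module cannot be located.
    return [name for name in names if importlib.util.find_spec(name) is None]


# Placeholder names: replace with the modules behind the extras you need (see setup.py).
print(missing_modules(["json", "spacy", "nltk"]))
```

Only top-level names are checked here; for a dotted name, `find_spec` imports the parent package first and raises if that parent is missing.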
What is Robustness Gym?
Robustness Gym is being developed to address the challenges of evaluating machine learning models today. It provides tools to evaluate and visualize model quality.
Along with Meerkat, we make it easy to load any kind of data (text, images, videos, time series) and quickly evaluate how well your models are performing.
Using Robustness Gym
import robustnessgym as rg
# Load any dataset
sst = rg.DataPanel.from_huggingface('sst', split='validation')
# Load any model
sst_model = rg.HuggingfaceModel('distilbert-base-uncased-finetuned-sst-2-english', is_classifier=True)
# Generate predictions for the first 2 examples in the dataset, using the "sentence" column as input
predictions = sst_model.predict_batch(sst[:2], ['sentence'])
# Run inference on an entire dataset & store the predictions in the dataset
sst = sst.update(lambda x: sst_model.predict_batch(x, ['sentence']), batch_size=4, is_batched_fn=True, pbar=True)
# Create a DevBench, which will contain slices to evaluate
sst_db = rg.DevBench()
# Add slices of data; to begin with, let's add the full dataset
# Slices are just datasets that you can track performance on
sst_db.add_slices([sst])
# Let's add another slice by filtering examples containing negation words
sst_db(rg.HasNegation(), sst, ['sentence'])
# Add any metrics you like
sst_db.add_aggregators({
    # Map from model name to dictionary of metrics
    'distilbert-base-uncased-finetuned-sst-2-english': {
        # This function uses the predictions we stored earlier to calculate accuracy
        'accuracy': lambda dp: (dp['label'].round() == dp['pred'].numpy()).mean()
    }
})
# Create a report
report = sst_db.create_report()
# Visualize: requires Plotly support in Jupyter; this generally works better in Jupyter Notebook than in JupyterLab
report.figure()
# Alternatively, save the report to a file (the kaleido engine requires the kaleido package)
report.figure().write_image('sst_db_report.png', engine='kaleido')
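Conceptually, a DevBench pairs named slices with per-model metrics and reports every metric on every slice. The plain-Python sketch below mimics that flow without Robustness Gym so the arithmetic is easy to follow; it is an illustration, not the library's implementation. It assumes each row is a dict with a `sentence`, a continuous `label` score that is rounded to 0/1 before comparison (mirroring the accuracy aggregator above), and an integer `pred`. The negation word list and the helpers `has_negation`, `accuracy` and `evaluate_slices` are hypothetical and do not reproduce `rg.HasNegation`.

```python
from typing import Any, Callable, Dict, List

Row = Dict[str, Any]

# Hypothetical word list for illustration; not Robustness Gym's HasNegation vocabulary.
NEGATION_WORDS = {"not", "no", "never", "nothing", "nobody", "none"}


def has_negation(text: str) -> bool:
    # Lower-cased whitespace tokens; contractions such as "didn't" also count.
    tokens = text.lower().split()
    return any(tok in NEGATION_WORDS or tok.endswith("n't") for tok in tokens)


def accuracy(rows: List[Row]) -> float:
    # Round the continuous label to 0/1, then compare it with the predicted class.
    if not rows:
        return float("nan")
    correct = sum(round(r["label"]) == r["pred"] for r in rows)
    return correct / len(rows)


def evaluate_slices(
    slices: Dict[str, List[Row]], metrics: Dict[str, Callable[[List[Row]], float]]
) -> Dict[str, Dict[str, float]]:
    # Result maps slice name -> metric name -> value, i.e. one report row per slice.
    return {name: {m: fn(rows) for m, fn in metrics.items()} for name, rows in slices.items()}


data = [
    {"sentence": "a gorgeous , witty film", "label": 0.9, "pred": 1},
    {"sentence": "it is not funny at all", "label": 0.2, "pred": 1},
    {"sentence": "i didn't hate it", "label": 0.6, "pred": 0},
    {"sentence": "a dull mess", "label": 0.1, "pred": 0},
]
slices = {"full": data, "has_negation": [r for r in data if has_negation(r["sentence"])]}
results = evaluate_slices(slices, {"accuracy": accuracy})
print(results)  # {'full': {'accuracy': 0.5}, 'has_negation': {'accuracy': 0.0}}
```

Python's built-in `round` (like NumPy's `round`) rounds halves to even, so a label of exactly 0.5 maps to class 0.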
Applying Built-in Subpopulations
# Create a SliceBuilder that creates subpopulations based on length: here, the shortest 10% and the longest 10% of examples.
length_sb = rg.NumTokensSubpopulation(intervals=[("0%", "10%"), ("90%", "100%")])
slices, membership = length_sb(dp=sst, columns=['sentence'])
# `slices` is a list of 2 DataPanel objects
# `membership` is a matrix of shape (n x 2)
for sl in slices:
    print(sl.identifier)
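A percentile-based length subpopulation can be pictured in three steps: score each example by its token count, turn each percentile interval into a pair of length cut-offs, and mark which examples fall between them. The sketch below does this in plain Python. It assumes whitespace tokenization and linear interpolation between ranks (the default method of `numpy.percentile`); `NumTokensSubpopulation` may tokenize and compute cut-offs differently, and the helper names are hypothetical.

```python
from typing import List, Sequence, Tuple


def percentile(sorted_vals: Sequence[float], pct: float) -> float:
    # Linear interpolation between the closest ranks (numpy.percentile's default).
    if not sorted_vals:
        raise ValueError("need at least one value")
    k = (len(sorted_vals) - 1) * pct / 100
    lo = int(k)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (k - lo)


def num_tokens(text: str) -> int:
    # Assumption: whitespace tokenization stands in for the library's tokenizer.
    return len(text.split())


def percentile_membership(texts: List[str], intervals: List[Tuple[str, str]]) -> List[List[int]]:
    scores = [num_tokens(t) for t in texts]
    ordered = sorted(scores)
    # Convert "10%"-style strings into numeric length cut-offs.
    bounds = [
        (percentile(ordered, float(lo.rstrip("%"))), percentile(ordered, float(hi.rstrip("%"))))
        for lo, hi in intervals
    ]
    # One row per example, one column per interval: an (n x k) 0/1 membership matrix.
    return [[int(low <= s <= high) for low, high in bounds] for s in scores]


texts = [" ".join(["word"] * n) for n in range(1, 11)]  # lengths 1..10
membership = percentile_membership(texts, [("0%", "10%"), ("90%", "100%")])
print(membership[0], membership[9])  # shortest and longest example
```

With ten examples of lengths 1 to 10, only the 1-word example lands in the bottom bucket and only the 10-word example in the top bucket; the rest belong to neither, which is why `membership` can contain all-zero rows.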
Creating Custom Subpopulations
def length(batch: rg.DataPanel, columns: list):
    # Score each example by the number of whitespace-separated words in the first column
    return [len(text.split()) for text in batch[columns[0]]]
# Create a subpopulation that buckets examples based on length
length_sp = rg.ScoreSubpopulation(intervals=[(0, 10), (10, 20)], score_fn=length)
slices, membership = length_sp(dp=sst, columns=['sentence'])
for sl in slices:
    print(sl.identifier)
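The `length` score function returns one number per example, and each `(low, high)` pair in `intervals` defines a bucket. A rough plain-Python picture of that bucketing follows. It assumes closed intervals, so a 10-word sentence falls into both `(0, 10)` and `(10, 20)`; whether `ScoreSubpopulation` includes or excludes the endpoints is an assumption here, so check the library's behavior before relying on boundary cases. `bucket_by_score` is a hypothetical helper, and a plain dict stands in for a DataPanel batch.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Batch = Dict[str, List[str]]


def length(batch: Batch, columns: List[str]) -> List[int]:
    # Same score as above: whitespace word count of the first column.
    return [len(text.split()) for text in batch[columns[0]]]


def bucket_by_score(
    batch: Batch,
    columns: List[str],
    score_fn: Callable[[Batch, List[str]], Sequence[float]],
    intervals: List[Tuple[float, float]],
) -> Tuple[List[List[int]], List[List[int]]]:
    scores = score_fn(batch, columns)
    # membership[i][j] == 1 when example i's score lies in intervals[j] (closed on both ends, by assumption).
    membership = [[int(lo <= s <= hi) for lo, hi in intervals] for s in scores]
    # Collect the indices of the examples in each interval, i.e. the slices.
    slices = [[i for i, row in enumerate(membership) if row[j]] for j in range(len(intervals))]
    return slices, membership


batch = {"sentence": ["short one", " ".join(["w"] * 10), " ".join(["w"] * 15), " ".join(["w"] * 25)]}
slices, membership = bucket_by_score(batch, ["sentence"], length, [(0, 10), (10, 20)])
print(slices)  # [[0, 1], [1, 2]]
```

The 25-word example belongs to no bucket, and the 10-word example belongs to both, so slices can overlap and need not cover the whole dataset.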
About
You can read more about the ideas underlying Robustness Gym in our paper on arXiv.
The Robustness Gym project began as a collaboration between Stanford Hazy Research, Salesforce Research and UNC Chapel Hill. We also have a website.
If you use Robustness Gym in your work, please use the following BibTeX entry:
@inproceedings{goel-etal-2021-robustness,
title = "Robustness Gym: Unifying the {NLP} Evaluation Landscape",
author = "Goel, Karan and
Rajani, Nazneen Fatema and
Vig, Jesse and
Taschdjian, Zachary and
Bansal, Mohit and
R{\'e}, Christopher",
booktitle = "Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies: Demonstrations",
month = jun,
year = "2021",
address = "Online",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/2021.naacl-demos.6",
pages = "42--55",
}
Built Distribution
Hashes for robustnessgym-0.1.3-py2.py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | b2d2cec414f1476509a835853c0355cb1bb34f0f95729e8462d9d889d5bad35d
MD5 | d4c4384e83f31f698ab8039173c670b3
BLAKE2b-256 | f272087a2e4386eae2e25b4a23dc2dc5dde4b0f0fa5bb304878b187b8d597f1b
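The SHA256 digest listed for the wheel lets you confirm that a downloaded file is intact. A minimal sketch with Python's standard `hashlib` streams the file in chunks and compares the result with the published value; the local file name is an assumption about where you saved the download, and `sha256_of` and `verify` are hypothetical helper names.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    # Read the file in chunks so large archives need not fit in memory.
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected_hex: str) -> bool:
    # hexdigest() is lower-case, so normalize the expected value before comparing.
    return sha256_of(path) == expected_hex.strip().lower()


if __name__ == "__main__":
    # Assumed download location; adjust to wherever the wheel was saved.
    wheel = Path("robustnessgym-0.1.3-py2.py3-none-any.whl")
    expected = "b2d2cec414f1476509a835853c0355cb1bb34f0f95729e8462d9d889d5bad35d"
    if wheel.exists():
        print("OK" if verify(wheel, expected) else "MISMATCH")
```

`pip hash <file>` prints the same SHA256 digest from the command line if you prefer not to write any code.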