Robustness Gym

Robustness Gym is a Python evaluation toolkit for machine learning models.

Getting Started | What is Robustness Gym? | Docs | Contributing | About

Getting started

pip install robustnessgym

Note: some parts of Robustness Gym rely on optional dependencies. If you know which ones you need, install them as extras, for example pip install robustnessgym[dev,text]. See setup.py for the full list of optional dependencies.
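
If the package is already installed, the extras it declares can also be read from its metadata instead of from setup.py. The sketch below uses only the standard library (importlib.metadata, Python 3.8+); the helper name list_extras is illustrative and not part of Robustness Gym.

```python
from importlib import metadata


def list_extras(package: str) -> list:
    """Return the optional-dependency groups ("extras") a package declares.

    Returns an empty list if the package is not installed.
    `list_extras` is a hypothetical helper, not a Robustness Gym function.
    """
    try:
        info = metadata.metadata(package)
    except metadata.PackageNotFoundError:
        return []
    # "Provides-Extra" appears once per declared extra in the package metadata.
    return sorted(info.get_all("Provides-Extra") or [])


# Prints the declared extras, or [] if robustnessgym is not installed.
print(list_extras("robustnessgym"))
```

Each name it prints can be used inside the square brackets of the pip install command.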

What is Robustness Gym?

Robustness Gym is being developed to address the challenges of evaluating machine learning models today. It provides tools to evaluate and visualize model quality.

Together with Meerkat, Robustness Gym makes it easy to load any kind of data (text, images, videos, time series) and quickly evaluate how well your models perform.

Using Robustness Gym

import robustnessgym as rg

# Load any dataset
sst = rg.DataPanel.from_huggingface('sst', split='validation')

# Load any model
sst_model = rg.HuggingfaceModel('distilbert-base-uncased-finetuned-sst-2-english', is_classifier=True)

# Generate predictions for the first 2 examples in the dataset, using the "sentence" column as input
predictions = sst_model.predict_batch(sst[:2], ['sentence'])

# Run inference on the entire dataset and store the predictions in the dataset
sst = sst.update(lambda x: sst_model.predict_batch(x, ['sentence']), batch_size=4, is_batched_fn=True, pbar=True)

# Create a DevBench, which will contain slices to evaluate
sst_db = rg.DevBench()

# Add slices of data; to begin with let's add the full dataset
# Slices are just datasets that you can track performance on
sst_db.add_slices([sst])

# Let's add another slice by filtering examples containing negation words
sst_db(rg.HasNegation(), sst, ['sentence'])

# Add any metrics you like
sst_db.add_aggregators({
    # Map from model name to dictionary of metrics
    'distilbert-base-uncased-finetuned-sst-2-english': {
        # This function uses the predictions we stored earlier to calculate accuracy
        'accuracy': lambda dp: (dp['label'].round() == dp['pred'].numpy()).mean()
    }
})

# Create a report
report = sst_db.create_report()

# Visualize the report as a figure. This requires plotly support in Jupyter and
# generally works better in Jupyter notebooks than in JupyterLab.
report.figure()

# Alternatively, save the report to a file (the 'kaleido' engine requires the kaleido package)
report.figure().write_image('sst_db_report.png', engine='kaleido')
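
The workflow above stores predictions, groups examples into slices, and computes a metric per slice. The library-free sketch below illustrates that idea on made-up data. The example sentences, the negation word list, and every function name here are assumptions for illustration, not Robustness Gym APIs or real SST results.

```python
# Toy examples: sentences, labels, and predictions are invented for illustration.
examples = [
    {"sentence": "a charming and funny film", "label": 1, "pred": 1},
    {"sentence": "not a good movie", "label": 0, "pred": 1},
    {"sentence": "i did not hate it", "label": 1, "pred": 0},
    {"sentence": "dull and lifeless", "label": 0, "pred": 0},
]

# Simplified, hypothetical word list; a real negation detector would be richer.
NEGATION_WORDS = {"not", "no", "never"}


def has_negation(example):
    """True if any whitespace-separated token is a negation word."""
    return any(tok in NEGATION_WORDS for tok in example["sentence"].lower().split())


def accuracy(rows):
    """Fraction of rows whose prediction matches the label (NaN if empty)."""
    if not rows:
        return float("nan")
    return sum(r["label"] == r["pred"] for r in rows) / len(rows)


# A slice is just a subset of the data whose performance is tracked separately.
slices = {
    "full dataset": examples,
    "has negation": [ex for ex in examples if has_negation(ex)],
}

# One metric value per slice, analogous to one row of a report.
report = {name: accuracy(rows) for name, rows in slices.items()}
for name, acc in report.items():
    print(f"{name}: accuracy={acc:.2f} (n={len(slices[name])})")
```

Reporting the metric per slice rather than only on the full dataset is what makes a weak spot, such as sentences containing negation, visible.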

Applying Built-in Subpopulations

# Create a slicebuilder that creates subpopulations based on length: here, the bottom and top 10% of examples by token count.
length_sb = rg.NumTokensSubpopulation(intervals=[("0%", "10%"), ("90%", "100%")])

slices, membership = length_sb(dp=sst, columns=['sentence'])
# `slices` is a list of 2 DataPanel objects, one per interval
# `membership` is a matrix of shape (n, 2), where n is the number of examples
for sl in slices:
    print(sl.identifier)
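
To make the shape of the membership matrix concrete, here is a plain-Python sketch of one plausible reading of percentage intervals: each percentage is converted to a token-count threshold, and an example belongs to an interval if its token count falls between the two thresholds. The linear-interpolation percentile and the inclusive bounds are assumptions; the library's exact boundary handling may differ.

```python
def percentile(sorted_vals, q):
    """Linearly interpolated percentile of a non-empty sorted list, q in [0, 100]."""
    pos = (len(sorted_vals) - 1) * q / 100
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    frac = pos - lo
    return sorted_vals[lo] * (1 - frac) + sorted_vals[hi] * frac


def membership_matrix(texts, intervals):
    """Return an n x k matrix of 0/1 flags: row i, column j is 1 if text i
    falls inside percentage interval j (inclusive bounds, an assumption)."""
    scores = [len(t.split()) for t in texts]  # whitespace token counts
    ordered = sorted(scores)
    # Convert ("0%", "10%")-style intervals into token-count thresholds once.
    bounds = [
        (percentile(ordered, float(lo.rstrip("%"))),
         percentile(ordered, float(hi.rstrip("%"))))
        for lo, hi in intervals
    ]
    return [[int(lo <= s <= hi) for lo, hi in bounds] for s in scores]


# Ten made-up sentences with 1, 2, ..., 10 tokens.
texts = [" ".join(["word"] * n) for n in range(1, 11)]
membership = membership_matrix(texts, [("0%", "10%"), ("90%", "100%")])
for text, row in zip(texts, membership):
    print(len(text.split()), row)
```

Each row corresponds to one example and each column to one interval, so an example in the middle of the length distribution has a row of zeros.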

Creating Custom Subpopulations

def length(batch: rg.DataPanel, columns: list):
    return [len(text.split()) for text in batch[columns[0]]]

# Create a subpopulation that buckets examples based on length
length_sp = rg.ScoreSubpopulation(intervals=[(0, 10), (10, 20)], score_fn=length)

slices, membership = length_sp(dp=sst, columns=['sentence'])
for sl in slices:
    print(sl.identifier)
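
The same bucketing idea can be sketched without the library: a score function maps a batch of texts to numbers, and each fixed interval collects the texts whose score falls inside it. The helper bucket_by_score is hypothetical, and the half-open intervals (lower bound included, upper bound excluded) are an assumption rather than documented Robustness Gym behavior.

```python
def length(batch):
    """Score function: whitespace token count for each text in the batch."""
    return [len(text.split()) for text in batch]


def bucket_by_score(texts, score_fn, intervals):
    """Group texts by the interval their score falls in.

    Intervals are treated as half-open [lo, hi); this is an assumption.
    Texts whose score matches no interval are left out of every bucket.
    """
    scores = score_fn(texts)
    return {
        (lo, hi): [t for t, s in zip(texts, scores) if lo <= s < hi]
        for lo, hi in intervals
    }


# Made-up texts with 2, 12, and 25 tokens.
texts = [
    "short one",
    "a b c d e f g h i j k l",
    " ".join(["w"] * 25),
]
buckets = bucket_by_score(texts, length, [(0, 10), (10, 20)])
for interval, members in buckets.items():
    print(interval, len(members))
```

Because the score function is passed in as an argument, swapping length for any other per-example score changes the subpopulations without touching the bucketing logic.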

About

You can read more about the ideas underlying Robustness Gym in our paper on arXiv.

The Robustness Gym project began as a collaboration between Stanford Hazy Research, Salesforce Research and UNC Chapel Hill. We also have a website.

If you use Robustness Gym in your work, please cite it using the following BibTeX entry:

@inproceedings{goel-etal-2021-robustness,
    title = "Robustness Gym: Unifying the {NLP} Evaluation Landscape",
    author = "Goel, Karan  and
      Rajani, Nazneen Fatema  and
      Vig, Jesse  and
      Taschdjian, Zachary  and
      Bansal, Mohit  and
      R{\'e}, Christopher",
    booktitle = "Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies: Demonstrations",
    month = jun,
    year = "2021",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2021.naacl-demos.6",
    pages = "42--55",
}
