Dataset Interfaces

This repository contains the code for our recent work:

Dataset Interfaces: Diagnosing Model Failures Using Controllable Counterfactual Generation
Joshua Vendrow*, Saachi Jain*, Logan Engstrom, Aleksander Madry
Paper: https://arxiv.org/abs/2302.07865
Blog post: https://gradientscience.org/dataset-interfaces/

Getting started

Install using pip, or clone our repository.

pip install dataset-interfaces

Example: For a walkthrough of the codebase, check out our example notebook, which shows how to construct a dataset interface for a subset of ImageNet and generate counterfactual examples.

Before running run_textual_inversion, initialize a 🤗 Accelerate environment with:

accelerate config

Constructing a Dataset Interface

Constructing a dataset interface consists of learning a class token for each class in a dataset; these tokens can then be included in textual prompts.
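For intuition, here is the standard textual inversion objective, assuming run_textual_inversion follows it (its exact loss, prompt templates, and hyperparameters are not documented here). Only the new token's embedding v is optimized; the image encoder E, the denoising network ε_θ, and the text encoder c_θ all stay frozen:

```latex
v_* = \arg\min_{v} \; \mathbb{E}_{x,\, y,\, \epsilon \sim \mathcal{N}(0, I),\, t}
\Big[ \big\| \epsilon - \epsilon_\theta\big(z_t,\, t,\, c_\theta(y; v)\big) \big\|_2^2 \Big],
\qquad z_t = \sqrt{\bar\alpha_t}\, \mathcal{E}(x) + \sqrt{1 - \bar\alpha_t}\, \epsilon
```

Here x is a training image of the class, y is a prompt containing the new token, t is a diffusion timestep with noise-schedule coefficient ᾱ_t, and c_θ(y; v) denotes the text encoding of y with the new token's embedding set to v.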

To learn a single token, we use the following function:

from dataset_interfaces import run_textual_inversion

embed = run_textual_inversion(
    train_path=train_path,  # path to a directory containing the training images for a single class
    token=token,            # text to use for the new token, e.g., "<plate>"
    class_name=class_name,  # natural-language class description, e.g., "plate"
)
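Because the class index c used later by generate is the position of a class in the lists passed to create_encoder, it helps to build the token strings and class names together in one fixed order. The sketch below is a hypothetical bookkeeping helper, not part of the library: make_token and build_class_lists are illustrative names, and deriving tokens like "<golden_retriever>" from class names is our assumption about a convenient naming scheme.

```python
import re


def make_token(class_name: str) -> str:
    """Hypothetical helper: turn 'golden retriever' into '<golden_retriever>'."""
    slug = re.sub(r"\s+", "_", class_name.strip().lower())
    return f"<{slug}>"


def build_class_lists(class_names):
    """Return (tokens, class_names) in a fixed, sorted order.

    The position of each class in these lists is the index `c`
    to use later when generating images for that class.
    """
    ordered = sorted(class_names)
    tokens = [make_token(name) for name in ordered]
    return tokens, ordered


tokens, class_names = build_class_lists(["plate", "golden retriever", "acorn"])
for c, (tok, name) in enumerate(zip(tokens, class_names)):
    print(c, tok, name)
# 0 <acorn> acorn
# 1 <golden_retriever> golden retriever
# 2 <plate> plate
```

Sorting is a design choice: it matches the alphabetical class ordering that torchvision.datasets.ImageFolder applies to class folders, which keeps indices consistent if the same classes are later loaded that way.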

Once all the class tokens are learned, we can create a custom tokenizer and text encoder with these tokens:

from dataset_interfaces import inference_utils as infer_utils

infer_utils.create_encoder(
    embeds=embeds,             # list of learned embeddings, one per class (from run_textual_inversion)
    tokens=tokens,             # list of token strings
    class_names=class_names,   # list of natural-language class descriptions
    encoder_root=encoder_root  # directory in which to store the tokenizer and encoder
)
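Conceptually, adding learned tokens to a tokenizer and text encoder means appending one vocabulary entry and one embedding row per token, so that each new token id looks up its learned vector. The toy sketch below illustrates only that bookkeeping, with a plain dict and list standing in for the vocabulary and embedding matrix; it is an assumption about what create_encoder does internally, not its implementation. (In 🤗 Transformers, the analogous real calls are tokenizer.add_tokens(...) followed by text_encoder.resize_token_embeddings(len(tokenizer)).)

```python
from typing import Dict, List


def add_tokens(vocab: Dict[str, int], embedding: List[List[float]],
               tokens: List[str], embeds: List[List[float]]) -> List[int]:
    """Append each new token to the vocabulary and its embedding as a new row.

    Assumes token ids are contiguous row indices into `embedding`.
    Returns the ids assigned to the new tokens.
    """
    if len(tokens) != len(embeds):
        raise ValueError("need exactly one embedding per token")
    dim = len(embedding[0])
    new_ids = []
    for tok, vec in zip(tokens, embeds):
        if tok in vocab:
            raise ValueError(f"token {tok!r} already in vocabulary")
        if len(vec) != dim:
            raise ValueError("embedding dimension mismatch")
        vocab[tok] = len(embedding)      # next free row index becomes the token id
        embedding.append(list(vec))
        new_ids.append(vocab[tok])
    return new_ids


# Toy base vocabulary with 3-dimensional embeddings (real CLIP embeddings are much larger).
vocab = {"a": 0, "photo": 1, "of": 2}
embedding = [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]]
ids = add_tokens(vocab, embedding, ["<plate>", "<acorn>"], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(ids)                          # [3, 4]
print(embedding[vocab["<plate>"]])  # [1.0, 2.0, 3.0]
```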

Generating Counterfactual Examples

We can now generate counterfactual examples by incorporating our learned tokens into textual prompts. The generate function generates images for a specific class in the dataset, indexed by the order in which classes were passed when constructing the encoder. In the text prompt, "<TOKEN>" acts as a placeholder for the class token.

from dataset_interfaces import generate

generate(
    encoder_root=encoder_root,
    c=c,                                          # index of a specific class
    prompts="a photo of a <TOKEN> in the grass",  # can be a single prompt or a list of prompts
    num_samples=10,
    random_seed=0,                                # optional; no seed is used by default
)
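The <TOKEN> placeholder can be thought of as a string substitution: for class c, each prompt has <TOKEN> replaced by that class's learned token before text encoding. The sketch below is an assumption about that step, not the library's code; fill_prompts is a hypothetical helper that, like generate, accepts either a single prompt or a list of prompts.

```python
from typing import List, Sequence, Union

PLACEHOLDER = "<TOKEN>"


def fill_prompts(prompts: Union[str, Sequence[str]], tokens: Sequence[str], c: int) -> List[str]:
    """Replace the <TOKEN> placeholder with the learned token for class index c."""
    if isinstance(prompts, str):
        prompts = [prompts]              # normalize a single prompt to a list
    token = tokens[c]
    filled = []
    for p in prompts:
        if PLACEHOLDER not in p:
            # Our choice: fail loudly rather than silently generate without the class.
            raise ValueError(f"prompt has no {PLACEHOLDER} placeholder: {p!r}")
        filled.append(p.replace(PLACEHOLDER, token))
    return filled


tokens = ["<acorn>", "<golden_retriever>", "<plate>"]
print(fill_prompts("a photo of a <TOKEN> in the grass", tokens, c=2))
# ['a photo of a <plate> in the grass']
```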

CLIP Metric

To directly evaluate the quality of the generated images, we use CLIP similarity to quantify both the presence of the object of interest and the presence of the desired distribution shift in each image.

We can measure CLIP similarity between a set of generated images and a given caption as follows:

sim_class = infer_utils.clip_similarity(imgs, "a photo of a dog")
sim_shift = infer_utils.clip_similarity(imgs, "a photo in the grass")
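CLIP similarity is typically the cosine similarity between CLIP's image embedding and its text embedding. Assuming clip_similarity follows that convention (we have not checked whether it rescales the score, e.g., by CLIP's logit scale), the sketch below shows the computation and one illustrative way to use it for filtering: keep an image only if it scores above a threshold for both the class caption and the shift caption. The 2-D embeddings and the threshold are made-up toy values, and the "both captions" rule is our assumption, not necessarily the benchmark's criterion.

```python
import math
from typing import List, Sequence


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine of the angle between two embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def keep_mask(image_embs, class_text_emb, shift_text_emb, threshold: float) -> List[bool]:
    """True for images whose similarity to BOTH captions exceeds `threshold` (illustrative rule)."""
    return [
        cosine_similarity(img, class_text_emb) > threshold
        and cosine_similarity(img, shift_text_emb) > threshold
        for img in image_embs
    ]


# Toy 2-D "embeddings": axis 0 ~ "dog", axis 1 ~ "grass".
class_text = [1.0, 0.0]    # stands in for "a photo of a dog"
shift_text = [0.0, 1.0]    # stands in for "a photo in the grass"
images = [[1.0, 1.0],      # dog in grass
          [1.0, 0.0],      # dog, no grass
          [0.0, 1.0]]      # grass, no dog
print(keep_mask(images, class_text, shift_text, threshold=0.5))
# [True, False, False]
```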

ImageNet* Benchmark

Our benchmark for the ImageNet dataset consists of two components: our 1,000 learned class tokens for ImageNet, and the images generated with these tokens under 23 distribution shifts.

ImageNet* Tokens

The 1,000 learned tokens are available on HuggingFace and can be downloaded with:

wget https://huggingface.co/datasets/madrylab/imagenet-star-tokens/resolve/main/tokens.zip

To generate images with these tokens, we first create a text encoder that includes them, which lets us use the tokens directly in text prompts:

token_path = "./tokens"  # path to the tokens downloaded from HuggingFace (after unzipping tokens.zip)
infer_utils.create_imagenet_star_encoder(token_path, encoder_root="./encoder_root_imagenet")

Now, we can generate counterfactual examples of ImageNet from a textual prompt (see the example notebook for a walkthrough):

from dataset_interfaces import generate

encoder_root = "./encoder_root_imagenet"
c = 207  # ImageNet class index for golden retriever
prompt = "a photo of a <TOKEN> wearing a hat"
generate(encoder_root, c, prompt, num_samples=10)

ImageNet* Images

Our benchmark contains images under 23 distribution shifts, with 50k images per shift (50 per class for 1,000 classes). These images are also available on HuggingFace. This repository also provides, in masks.npy, a mask for each distribution shift indicating which images we filter out using our CLIP metrics.
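The benchmark size follows directly from the figures stated above; this short calculation only makes the totals explicit.

```python
NUM_CLASSES = 1000
IMAGES_PER_CLASS = 50
NUM_SHIFTS = 23

images_per_shift = NUM_CLASSES * IMAGES_PER_CLASS   # 50,000 per shift
total_images = images_per_shift * NUM_SHIFTS          # 1,150,000 before CLIP filtering

print(f"{images_per_shift:,} images per shift, {total_images:,} in total")
# 50,000 images per shift, 1,150,000 in total
```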

We provide a wrapper on top of torchvision.datasets.ImageFolder that constructs a dataset object and uses this mask to filter the images in the benchmark. For example, we can make a dataset object for a single shift as follows:

from dataset_interfaces import utils

root = "./imagenet-star"     # path to the dataset downloaded from HuggingFace
mask_path = "./masks.npy"    # the path to the mask file
shift = "in_the_snow"        # the distribution shift of interest

ds = utils.ImageNet_Star_Dataset(
    root, 
    shift=shift,
    mask_path=mask_path
)
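Filtering with a mask amounts to keeping only the samples whose mask entry is true. The sketch below shows that idea on an ImageFolder-style list of (path, class_index) pairs. It assumes, hypothetically, that a shift's mask is a flat boolean sequence aligned with the dataset's sample order and that True means "keep"; we have not confirmed the actual layout of masks.npy or the internals of ImageNet_Star_Dataset, and the file paths are made up.

```python
from typing import List, Sequence, Tuple

Sample = Tuple[str, int]   # (image path, class index), like torchvision's ImageFolder.samples


def apply_mask(samples: Sequence[Sample], mask: Sequence[bool]) -> List[Sample]:
    """Keep samples whose mask entry is True (assumed convention: True = keep)."""
    if len(samples) != len(mask):
        raise ValueError("mask length must match the number of samples")
    return [s for s, keep in zip(samples, mask) if keep]


samples = [
    ("in_the_snow/n01440764/0.jpg", 0),
    ("in_the_snow/n01440764/1.jpg", 0),
    ("in_the_snow/n01443537/0.jpg", 1),
]
mask = [True, False, True]
print(apply_mask(samples, mask))
# [('in_the_snow/n01440764/0.jpg', 0), ('in_the_snow/n01443537/0.jpg', 1)]
```

With a real ImageFolder, the same filter would need to be applied consistently to its samples list and its targets so that paths and labels stay aligned.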

Citation

To cite this paper, please use the following BibTeX entry:

@inproceedings{vendrow2023dataset,
   title = {Dataset Interfaces: Diagnosing Model Failures Using Controllable Counterfactual Generation},
   author = {Joshua Vendrow and Saachi Jain and Logan Engstrom and Aleksander Madry}, 
   booktitle = {ArXiv preprint arXiv:2302.07865},
   year = {2023}
}

Maintainers:

Josh Vendrow
Saachi Jain
