Dataset Interfaces
This repository contains the code for our recent work:
Dataset Interfaces: Diagnosing Model Failures Using Controllable Counterfactual Generation
Joshua Vendrow*, Saachi Jain*, Logan Engstrom, Aleksander Madry
Paper: https://arxiv.org/abs/2302.07865
Blog post: https://gradientscience.org/dataset-interfaces/
Getting started
Install using pip, or clone our repository.
pip install dataset-interfaces
Example: For a walkthrough of the codebase, check out our example notebook. This notebook shows how to construct a dataset interface for a subset of ImageNet and generate counterfactual examples.
Before running run_textual_inversion, initialize a 🤗 Accelerate environment with:
accelerate config
Constructing a Dataset Interface
Constructing a dataset interface consists of learning a class token for each class in a dataset, which can then be included in textual prompts.
To learn a single token, we use the following function:
from dataset_interfaces import run_textual_inversion
embed = run_textual_inversion(
    train_path=train_path,  # path to directory with training set for a single class
    token=token,  # text to use for new token, e.g., "<plate>"
    class_name=class_name,  # natural language class description, e.g., "plate"
)
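To build a full interface, the call above is repeated once per class. A minimal sketch of collecting the per-class arguments, assuming an ImageFolder-style layout (one subdirectory of training images per class, sorted in the same order as the class names) and a hypothetical `<class_name>` token-naming convention; `make_token` and `collect_class_args` are illustrative helpers, not part of the package:

```python
import os


def make_token(class_name):
    # Hypothetical convention: "golden retriever" -> "<golden_retriever>"
    return "<" + class_name.strip().lower().replace(" ", "_") + ">"


def collect_class_args(data_root, class_names):
    """Build one keyword-argument dict per class for run_textual_inversion.

    Assumes data_root/<class_dir>/ holds the training images of each class,
    and that the sorted class directories line up with class_names.
    """
    class_dirs = sorted(
        d for d in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, d))
    )
    if len(class_dirs) != len(class_names):
        raise ValueError("expected exactly one directory per class name")
    return [
        {
            "train_path": os.path.join(data_root, d),
            "token": make_token(name),
            "class_name": name,
        }
        for d, name in zip(class_dirs, class_names)
    ]


# Usage (not executed here):
# args = collect_class_args("./train", ["plate", "golden retriever"])
# embeds = [run_textual_inversion(**a) for a in args]
# tokens = [a["token"] for a in args]
# class_names = [a["class_name"] for a in args]
```

Keeping the three lists derived from the same `args` list keeps them aligned, which matters for the encoder step that follows.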
Once all the class tokens are learned, we can create a custom tokenizer and text encoder with these tokens:
import inference_utils as infer_utils
infer_utils.create_encoder(
    embeds=embeds,  # list of learned embeddings (from the code block above)
    tokens=tokens,  # list of token strings
    class_names=class_names,  # list of natural language class descriptions
    encoder_root=encoder_root,  # path where to store the tokenizer and encoder
)
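The class index later passed to generate is the position of a class in these lists, so the three lists must have the same length and the same order. A minimal sketch of a sanity check one might run before calling create_encoder; `check_encoder_inputs` is a hypothetical helper, not part of the package:

```python
def check_encoder_inputs(embeds, tokens, class_names):
    """Validate the lists passed to create_encoder and return token -> class index.

    Hypothetical helper: the index of each entry is the class index that
    generate expects, so the lists must be aligned and tokens distinct.
    """
    if not (len(embeds) == len(tokens) == len(class_names)):
        raise ValueError("embeds, tokens and class_names must have equal length")
    if len(set(tokens)) != len(tokens):
        raise ValueError("each class needs a distinct token string")
    return {tok: idx for idx, tok in enumerate(tokens)}
```

The returned mapping is a convenient way to look up the `c` to use for a given token.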
Generating Counterfactual Examples
We can now generate counterfactual examples by incorporating our learned tokens in textual prompts. The generate function generates images for a specific class in the dataset (indexed in the order that classes are passed when constructing the encoder). When specifying the text prompt, "<TOKEN>" acts as a placeholder for the class token.
from dataset_interfaces import generate
generate(
    encoder_root=encoder_root,
    c=c,  # index of a specific class
    prompts="a photo of a <TOKEN> in the grass",  # can be a single prompt or a list of prompts
    num_samples=10,
    random_seed=0,  # no seed by default
)
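Conceptually, each "<TOKEN>" placeholder is swapped for the learned token of class c before the prompt reaches the text encoder. A minimal sketch of that substitution, assuming the tokens list from the encoder step; `fill_prompts` is a hypothetical illustration, not the library's internal code, and rejecting prompts without a placeholder is a design choice of this sketch:

```python
def fill_prompts(prompts, tokens, c, placeholder="<TOKEN>"):
    """Replace the placeholder with the token of class index c.

    Accepts a single prompt string or a list of prompts, mirroring the
    `prompts` argument of generate, and always returns a list.
    """
    if isinstance(prompts, str):
        prompts = [prompts]
    token = tokens[c]
    filled = []
    for p in prompts:
        if placeholder not in p:
            # A prompt without the placeholder would not mention the class at all.
            raise ValueError(f"prompt is missing {placeholder!r}: {p!r}")
        filled.append(p.replace(placeholder, token))
    return filled
```

For example, with `tokens = ["<plate>", "<dog>"]` and `c = 1`, the prompt above becomes "a photo of a <dog> in the grass".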
CLIP Metric
To directly evaluate the quality of the generated images, we use CLIP similarity to quantify both the presence of the object of interest and the desired distribution shift in each image.
We can measure CLIP similarity between a set of generated images and a given caption as follows:
sim_class = infer_utils.clip_similarity(imgs, "a photo of a dog")
sim_shift = infer_utils.clip_similarity(imgs, "a photo in the grass")
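CLIP similarity is commonly computed as the cosine similarity between a CLIP image embedding and a CLIP text embedding. A minimal sketch of the scoring and of a filter built on the two scores, assuming the embeddings have already been produced by a CLIP model; the thresholds and the rule of requiring both scores to pass are illustrative assumptions, not the criterion used to build the benchmark masks:

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def clip_scores(image_embeds, text_embed):
    """Score every image embedding against a single caption embedding."""
    return [cosine_similarity(img, text_embed) for img in image_embeds]


def keep_mask(sim_class, sim_shift, class_thresh, shift_thresh):
    """Keep an image only if it scores high enough on both captions.

    Hypothetical filtering rule; the thresholds must be chosen by the user.
    """
    return [
        sc >= class_thresh and ss >= shift_thresh
        for sc, ss in zip(sim_class, sim_shift)
    ]
```

Scoring against two captions separates the two failure modes: an image can show the shift without the object, or the object without the shift.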
ImageNet* Benchmark
Our benchmark for the ImageNet dataset consists of two components: our 1,000 learned class tokens for ImageNet, and the images generated by these tokens in 23 distribution shifts.
ImageNet* Tokens
The 1,000 learned tokens are available on HuggingFace and can be downloaded with:
wget https://huggingface.co/datasets/madrylab/imagenet-star-tokens/resolve/main/tokens.zip
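The same archive can be fetched and unpacked from Python. A minimal sketch using only the standard library; it assumes the archive unpacks into a tokens/ folder matching the "./tokens" path used in the next step, which may not match the archive's actual layout:

```python
import os
import urllib.request
import zipfile

TOKENS_URL = (
    "https://huggingface.co/datasets/madrylab/imagenet-star-tokens"
    "/resolve/main/tokens.zip"
)


def download_tokens(url=TOKENS_URL, zip_path="tokens.zip"):
    """Download the token archive unless it is already present locally."""
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(url, zip_path)
    return zip_path


def extract_tokens(zip_path, dest="."):
    """Unpack the archive into dest and return the extracted member names."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest)
        return zf.namelist()


# Usage (requires network access):
# extract_tokens(download_tokens(), dest=".")
```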
To generate images with these tokens, we first create a text encoder with the tokens, which we use to seamlessly integrate the tokens in text prompts:
token_path = "./tokens"  # path to the tokens from HuggingFace
infer_utils.create_imagenet_star_encoder(token_path, encoder_root="./encoder_root_imagenet")
Now we can generate counterfactual examples of ImageNet from a textual prompt (see the example notebook for a walkthrough):
from dataset_interfaces import generate
encoder_root = "./encoder_root_imagenet"
c = 207 # the class for golden retriever
prompt = "a photo of a <TOKEN> wearing a hat"
generate(encoder_root, c, prompt, num_samples=10)
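Because generate also accepts a list of prompts, several shifts can be requested in one call. A minimal sketch that turns benchmark-style shift names (such as "in_the_snow") into placeholder prompts; the underscore naming convention for other shifts and the `shift_prompts` helper are assumptions for illustration:

```python
def shift_prompts(shifts, template="a photo of a <TOKEN> {}"):
    """Turn shift names such as "in_the_snow" into "<TOKEN>" prompts."""
    return [template.format(s.replace("_", " ")) for s in shifts]


# Usage with the encoder built above (not executed here):
# prompts = shift_prompts(["in_the_snow", "in_the_grass"])
# generate(encoder_root, c, prompts, num_samples=10)
```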
ImageNet* Images
Our benchmark contains images in 23 distribution shifts, with 50k images per shift (50 per class for 1,000 classes). These images are also available on HuggingFace. In this repo we also provide, at masks.npy, masks for each distribution shift indicating which images we filter out with our CLIP metrics.
We provide a wrapper on top of torchvision.datasets.ImageFolder to construct a dataset object that filters the images in the benchmark using this mask. We can then make a dataset object for a shift as follows:
from dataset_interfaces import utils
root = "./imagenet-star"  # the path to the dataset downloaded from HuggingFace
mask_path = "./masks.npy" # the path to the mask file
shift = "in_the_snow" # the distribution shift of interest
ds = utils.ImageNet_Star_Dataset(
root,
shift=shift,
mask_path=mask_path
)
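The filtering idea behind the wrapper can be illustrated without torchvision. A minimal sketch, assuming the mask for one shift is a boolean sequence aligned with the ImageFolder-style `(path, class_index)` entries of `.samples`, where True means keep; the actual layout of masks.npy and the wrapper's internals may differ:

```python
from collections import Counter


def filter_samples(samples, mask):
    """Keep the (path, label) samples whose mask entry is True.

    Assumes mask[i] refers to samples[i], i.e. the same ordering ImageFolder uses.
    """
    if len(samples) != len(mask):
        raise ValueError("mask length must match the number of samples")
    return [s for s, keep in zip(samples, mask) if keep]


def count_per_class(samples):
    """Count how many samples survive filtering for each class index."""
    return Counter(label for _, label in samples)
```

Since filtering removes a different number of images per class, checking `count_per_class` after filtering shows how far each class falls below the original 50 images.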
Citation
To cite this paper, please use the following BibTeX entry:
@inproceedings{vendrow2023dataset,
title = {Dataset Interfaces: Diagnosing Model Failures Using Controllable Counterfactual Generation},
author = {Joshua Vendrow and Saachi Jain and Logan Engstrom and Aleksander Madry},
booktitle = {ArXiv preprint arXiv:2302.07865},
year = {2023}
}