diffusers-interpret: model explainability for 🤗 Diffusers
Project description
Diffusers-Interpret 🤗🧨🕵️♀️
diffusers-interpret is a model explainability tool built on top of 🤗 Diffusers.
Installation
Install directly from PyPI:
pip install diffusers-interpret
Usage
Let's see how we can interpret the new 🎨🎨🎨 Stable Diffusion!
# make sure you're logged in with `huggingface-cli login`
import torch
from contextlib import nullcontext
from diffusers import StableDiffusionPipeline
from diffusers_interpret import StableDiffusionPipelineExplainer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    # FP16 is not supported on 'cpu'
    revision='fp16' if device != 'cpu' else None,
    torch_dtype=torch.float16 if device != 'cpu' else None
).to(device)
# pass pipeline to the explainer class
explainer = StableDiffusionPipelineExplainer(pipe)
# generate an image with `explainer`
prompt = "A cute corgi with the Eiffel Tower in the background"
generator = torch.Generator(device).manual_seed(2022)
with torch.autocast('cuda') if device == 'cuda' else nullcontext():
    output = explainer(
        prompt,
        num_inference_steps=15,
        generator=generator
    )
To check the final generated image:
output['sample']
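If you want to keep the result, here is a minimal sketch for saving it to disk (it assumes output['sample'] holds a PIL image, or a list containing one, as returned by the underlying pipeline):
# minimal sketch: persist the generated image to disk
# assumes `output['sample']` is a PIL image or a list containing one
sample = output['sample']
image = sample[0] if isinstance(sample, (list, tuple)) else sample
image.save('corgi_eiffel_tower.png')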
You can also check the image that the diffusion process generated at the end of each step.
For example, to see the image from step 10:
output['all_samples_during_generation'][10]
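To inspect the whole denoising progression at once, a small sketch along these lines (assuming each entry renders like output['sample'] above) saves every intermediate image:
# sketch only: write each intermediate image to disk
# assumes every entry behaves like `output['sample']` above
for step, sample in enumerate(output['all_samples_during_generation']):
    image = sample[0] if isinstance(sample, (list, tuple)) else sample
    image.save(f'step_{step:02d}.png')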
To check how a token in the input prompt influenced the generation, you can check the token attribution scores:
>>> output['token_attributions'] # (token, attribution)
[('a', 1063.0526),
('cute', 415.62888),
('corgi', 6430.694),
('with', 1874.0208),
('the', 1223.2847),
('eiffel', 4756.4556),
('tower', 4490.699),
('in', 2463.1294),
('the', 655.4624),
('background', 3997.9395)]
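Since these are plain (token, score) pairs, a quick sort (helper code, not part of the library) surfaces the most influential tokens:
# rank tokens by raw attribution score
top_tokens = sorted(output['token_attributions'], key=lambda pair: pair[1], reverse=True)
print(top_tokens[:3])
# [('corgi', 6430.694), ('eiffel', 4756.4556), ('tower', 4490.699)]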
Or their normalized version, expressed as percentages:
>>> output['normalized_token_attributions'] # (token, attribution_percentage)
[('a', 3.884),
('cute', 1.519),
('corgi', 23.495),
('with', 6.847),
('the', 4.469),
('eiffel', 17.378),
('tower', 16.407),
('in', 8.999),
('the', 2.395),
('background', 14.607)]
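The normalized scores appear to be the raw attributions rescaled to sum to 100%, which you can reproduce yourself:
# illustration: recompute the normalization from the raw scores
total = sum(score for _, score in output['token_attributions'])
normalized = [(token, round(100 * score / total, 3)) for token, score in output['token_attributions']]
# e.g. ('corgi', 23.495) and ('eiffel', 17.378), matching the values above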
diffusers-interpret can also compute these token attributions for a specific part of the generated image.
To do that, call the explainer with a 2D bounding box defined in explanation_2d_bounding_box:
generator = torch.Generator(device).manual_seed(2022) # re-use generator
with torch.autocast('cuda') if device == 'cuda' else nullcontext():
    output = explainer(
        prompt,
        num_inference_steps=15,
        generator=generator,
        explanation_2d_bounding_box=((70, 180), (400, 435)),  # (upper left corner, bottom right corner)
    )
output['sample']
The generated image now has a red bounding box indicating the region of the image that is being explained.
The token attributions are now computed only for the specified area of the image.
>>> output['normalized_token_attributions'] # (token, attribution_percentage)
[('a', 1.891),
('cute', 1.344),
('corgi', 23.115),
('with', 11.995),
('the', 7.981),
('eiffel', 5.162),
('tower', 11.603),
('in', 11.99),
('the', 1.87),
('background', 23.05)]
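A small comparison loop (a hypothetical helper, assuming you saved the normalized attributions from the first, full-image run before calling the explainer again) makes it easy to see which tokens gained or lost importance inside the box:
# `full_attributions` holds the `normalized_token_attributions` saved from the full-image run;
# token order is the same in both runs, so zip lines the two lists up
for (token, full_pct), (_, box_pct) in zip(full_attributions, output['normalized_token_attributions']):
    print(f'{token:>10}: {full_pct:6.2f}% (full image) -> {box_pct:6.2f}% (bounding box)')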
Check other functionalities and more implementation examples here.
Future Development
- Add interactive display of all the images that were generated in the diffusion process
- Add interactive bounding-box and token attributions visualization
- Add unit tests
- Add example for diffusers_interpret.LDMTextToImagePipelineExplainer
- Do not require another generation every time the explanation_2d_bounding_box argument is changed
- Add more explainability methods
Contributing
Feel free to open an Issue or create a Pull Request and let's get started 🚀
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
File details
Details for the file diffusers-interpret-0.0.1.tar.gz.
File metadata
- Download URL: diffusers-interpret-0.0.1.tar.gz
- Upload date:
- Size: 8.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 41cd7288420136be4e86263e6fea760d8931c733126e636c52e1e8becbe25bf7 |
| MD5 | 95f3cbf3e731dd39a311d387ac76f398 |
| BLAKE2b-256 | 334a7f2e8dcd3d4e1d097fed388f6cc191714cac1593f5de6ca170d85a82ea1f |