diffusers-interpret: model explainability for 🤗 Diffusers
`diffusers-interpret` is a model explainability tool built on top of 🤗 Diffusers.
Installation
Install directly from PyPI:
```shell
pip install diffusers-interpret
```
Usage
Let's see how we can interpret the new 🎨🎨🎨 Stable Diffusion!
```python
# make sure you're logged in with `huggingface-cli login`
import torch
from contextlib import nullcontext

from diffusers import StableDiffusionPipeline
from diffusers_interpret import StableDiffusionPipelineExplainer

device = 'cuda' if torch.cuda.is_available() else 'cpu'

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
    # FP16 is not working for 'cpu'
    revision='fp16' if device != 'cpu' else None,
    torch_dtype=torch.float16 if device != 'cpu' else None
).to(device)

# pass pipeline to the explainer class
explainer = StableDiffusionPipelineExplainer(pipe)

# generate an image with `explainer`
prompt = "A cute corgi with the Eiffel Tower in the background"
generator = torch.Generator(device).manual_seed(2022)
with torch.autocast('cuda') if device == 'cuda' else nullcontext():
    output = explainer(
        prompt,
        num_inference_steps=15,
        generator=generator
    )
```
To check the final generated image:
```python
output['sample']
```
You can also check all the images that the diffusion process generated at the end of each step.
To see how each token in the input prompt influenced the generation, inspect the token attribution scores:
```python
>>> output['token_attributions'] # (token, attribution)
[('a', 1063.0526),
 ('cute', 415.62888),
 ('corgi', 6430.694),
 ('with', 1874.0208),
 ('the', 1223.2847),
 ('eiffel', 4756.4556),
 ('tower', 4490.699),
 ('in', 2463.1294),
 ('the', 655.4624),
 ('background', 3997.9395)]
```
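The raw scores are unnormalized, so they are mainly useful relative to each other. For instance, the most and least influential tokens can be picked out with plain Python (the list below simply copies the illustrative values shown above):

```python
# Illustrative (token, attribution) pairs copied from the example output above.
token_attributions = [
    ('a', 1063.0526), ('cute', 415.62888), ('corgi', 6430.694),
    ('with', 1874.0208), ('the', 1223.2847), ('eiffel', 4756.4556),
    ('tower', 4490.699), ('in', 2463.1294), ('the', 655.4624),
    ('background', 3997.9395),
]

# Sort tokens from most to least influential.
ranked = sorted(token_attributions, key=lambda pair: pair[1], reverse=True)
print(ranked[0])  # ('corgi', 6430.694)
```

Unsurprisingly, the content words ("corgi", "eiffel", "tower", "background") dominate the stop words.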
Or their normalized version, as percentages:

```python
>>> output['normalized_token_attributions'] # (token, attribution_percentage)
[('a', 3.884),
 ('cute', 1.519),
 ('corgi', 23.495),
 ('with', 6.847),
 ('the', 4.469),
 ('eiffel', 17.378),
 ('tower', 16.407),
 ('in', 8.999),
 ('the', 2.395),
 ('background', 14.607)]
```
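The normalization itself is straightforward: each raw score is divided by the sum of all scores. A minimal sketch of that computation, using the illustrative raw values from the first listing (this reimplements the idea for clarity; it is not the library's own code):

```python
# Raw (token, attribution) pairs from the example output above.
raw = [('a', 1063.0526), ('cute', 415.62888), ('corgi', 6430.694),
       ('with', 1874.0208), ('the', 1223.2847), ('eiffel', 4756.4556),
       ('tower', 4490.699), ('in', 2463.1294), ('the', 655.4624),
       ('background', 3997.9395)]

# Each token's share of the total attribution mass, in percent.
total = sum(score for _, score in raw)
normalized = [(token, round(100 * score / total, 3)) for token, score in raw]
print(normalized[2])  # ('corgi', 23.495)
```

The percentages sum to 100 (up to rounding), matching the `normalized_token_attributions` output above.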
`diffusers-interpret` can also compute these token attributions for a particular part of the image. To do that, call `explainer` with a 2D bounding box passed via the `explanation_2d_bounding_box` argument:
```python
generator = torch.Generator(device).manual_seed(2022)  # re-use generator
with torch.autocast('cuda') if device == 'cuda' else nullcontext():
    output = explainer(
        prompt,
        num_inference_steps=15,
        generator=generator,
        explanation_2d_bounding_box=((70, 180), (400, 435)),  # (upper left corner, bottom right corner)
    )
output['sample']
```
The generated image now has a red bounding box indicating the region being explained, and the token attributions are computed only for that area of the image:
```python
>>> output['normalized_token_attributions'] # (token, attribution_percentage)
[('a', 1.891),
 ('cute', 1.344),
 ('corgi', 23.115),
 ('with', 11.995),
 ('the', 7.981),
 ('eiffel', 5.162),
 ('tower', 11.603),
 ('in', 11.99),
 ('the', 1.87),
 ('background', 23.05)]
```
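The bounding box is given in pixel coordinates as (upper-left corner, bottom-right corner). As a quick sanity check of what the box in the call above covers, here is a tiny helper (ours, purely for illustration, not part of the library), assuming the usual (x, y) pixel convention and the default 512×512 Stable Diffusion output:

```python
def bbox_size(bounding_box):
    """Return (width, height) of a ((x0, y0), (x1, y1)) pixel bounding box."""
    (x0, y0), (x1, y1) = bounding_box
    return x1 - x0, y1 - y0

# The box used in the example call above.
width, height = bbox_size(((70, 180), (400, 435)))

# Fraction of a 512x512 image covered by the box.
area_fraction = round(width * height / (512 * 512), 3)
print(width, height, area_fraction)  # 330 255 0.321
```

So the attributions above are restricted to roughly a third of the image.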
Check other functionalities and more implementation examples here.
Future Development
- Add interactive display of all the images that were generated in the diffusion process
- Add interactive bounding-box and token attributions visualization
- Add unit tests
- Add example for `diffusers_interpret.LDMTextToImagePipelineExplainer`
- Do not require another generation every time the `explanation_2d_bounding_box` argument is changed
- Add more explainability methods
Contributing
Feel free to open an Issue or create a Pull Request and let's get started 🚀
Credits
A special thanks to @andrewizbatista for creating a great image slider to show all the generated images during diffusion! 💪
Source Distribution
Hashes for diffusers-interpret-0.2.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | b2b475d376aa6d13e82a7680391a13463cecc2dd29a86270648e463f2d09ca3b |
| MD5 | 68c2a68a7546e81c8fae22c600a84f35 |
| BLAKE2b-256 | ae9371b38f766c739fc87c304aac0291c73e77b0e17e7fd72d1fae910716eb46 |