diffusers-interpret: model explainability for 🤗 Diffusers
diffusers-interpret is a model explainability tool built on top of 🤗 Diffusers.
Installation
Install directly from PyPI:
pip install diffusers-interpret
Usage
Let's see how we can interpret the new 🎨🎨🎨 Stable Diffusion!
# make sure you're logged in with `huggingface-cli login`
import torch
from contextlib import nullcontext
from diffusers import StableDiffusionPipeline
from diffusers_interpret import StableDiffusionPipelineExplainer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    # FP16 weights are not supported on CPU
    revision='fp16' if device != 'cpu' else None,
    torch_dtype=torch.float16 if device != 'cpu' else None
).to(device)
# pass pipeline to the explainer class
explainer = StableDiffusionPipelineExplainer(pipe)
# generate an image with `explainer`
prompt = "A cute corgi with the Eiffel Tower in the background"
generator = torch.Generator(device).manual_seed(2022)
with torch.autocast('cuda') if device == 'cuda' else nullcontext():
    output = explainer(
        prompt,
        num_inference_steps=15,
        generator=generator
    )
To check the final generated image:
output['sample']
You can also check the image that the diffusion process generated at the end of each step.
For example, to see the image from step 10:
output['all_samples_during_generation'][10]
To see how each token in the input prompt influenced the generation, check the token attribution scores:
>>> output['token_attributions'] # (token, attribution)
[('a', 1063.0526),
('cute', 415.62888),
('corgi', 6430.694),
('with', 1874.0208),
('the', 1223.2847),
('eiffel', 4756.4556),
('tower', 4490.699),
('in', 2463.1294),
('the', 655.4624),
('background', 3997.9395)]
Or their normalized version, as percentages:
>>> output['normalized_token_attributions'] # (token, attribution_percentage)
[('a', 3.884),
('cute', 1.519),
('corgi', 23.495),
('with', 6.847),
('the', 4.469),
('eiffel', 17.378),
('tower', 16.407),
('in', 8.999),
('the', 2.395),
('background', 14.607)]
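The normalized values are simply each raw score divided by the sum of all raw scores, expressed as a percentage. A minimal sketch of that computation, using the raw scores from the output above (the `normalize` helper is ours for illustration, not part of the library):

```python
# Raw (token, attribution) pairs as returned in output['token_attributions'].
token_attributions = [
    ('a', 1063.0526), ('cute', 415.62888), ('corgi', 6430.694),
    ('with', 1874.0208), ('the', 1223.2847), ('eiffel', 4756.4556),
    ('tower', 4490.699), ('in', 2463.1294), ('the', 655.4624),
    ('background', 3997.9395),
]

def normalize(attributions):
    # Each token's share of the total attribution mass, in percent.
    total = sum(score for _, score in attributions)
    return [(token, round(100 * score / total, 3)) for token, score in attributions]

normalized = normalize(token_attributions)
# e.g. ('corgi', 23.495), matching output['normalized_token_attributions']
```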
diffusers-interpret can also compute token attributions for the generation of a particular part of the image.
To do that, call the explainer with a 2D bounding box defined in explanation_2d_bounding_box:
generator = torch.Generator(device).manual_seed(2022)  # re-use generator
with torch.autocast('cuda') if device == 'cuda' else nullcontext():
    output = explainer(
        prompt,
        num_inference_steps=15,
        generator=generator,
        explanation_2d_bounding_box=((70, 180), (400, 435)),  # (upper left corner, bottom right corner)
    )
output['sample']
The generated image now has a red bounding box to indicate the region of the image that is being explained.
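For reference, assuming the default 512×512 output resolution of Stable Diffusion v1, the box above covers roughly a third of the image. A quick sanity check of the coordinate convention (plain arithmetic, not a library call):

```python
# Bounding box as passed to the explainer: (upper left, bottom right), in pixels.
(x0, y0), (x1, y1) = (70, 180), (400, 435)

box_area = (x1 - x0) * (y1 - y0)   # 330 * 255 = 84150 pixels
image_area = 512 * 512             # 262144 pixels
fraction = box_area / image_area   # ~0.32, i.e. about a third of the image
```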
The token attributions are now computed only for the specified area of the image.
>>> output['normalized_token_attributions'] # (token, attribution_percentage)
[('a', 1.891),
('cute', 1.344),
('corgi', 23.115),
('with', 11.995),
('the', 7.981),
('eiffel', 5.162),
('tower', 11.603),
('in', 11.99),
('the', 1.87),
('background', 23.05)]
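Comparing the two normalized outputs shows how restricting the explanation to the box shifts attribution: tokens describing the boxed region gain weight while others lose it. A small sketch using the values from the two outputs above (variable names are ours; the duplicate 'the' entries are omitted because dictionary keys must be unique):

```python
# Normalized attributions for the full image vs. the bounding-box region.
full_image = {'a': 3.884, 'cute': 1.519, 'corgi': 23.495, 'with': 6.847,
              'eiffel': 17.378, 'tower': 16.407, 'in': 8.999,
              'background': 14.607}
boxed = {'a': 1.891, 'cute': 1.344, 'corgi': 23.115, 'with': 11.995,
         'eiffel': 5.162, 'tower': 11.603, 'in': 11.99,
         'background': 23.05}

# Change in attribution share when explaining only the boxed region.
deltas = {tok: round(boxed[tok] - full_image[tok], 3) for tok in full_image}
gained = max(deltas, key=deltas.get)  # 'background' gains the most
lost = min(deltas, key=deltas.get)    # 'eiffel' loses the most
```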
Check other functionalities and more implementation examples here.
Future Development
- Add interactive display of all the images that were generated in the diffusion process
- Add interactive bounding-box and token attributions visualization
- Add unit tests
- Add example for diffusers_interpret.LDMTextToImagePipelineExplainer
- Do not require another generation every time the explanation_2d_bounding_box argument is changed
- Add more explainability methods
Contributing
Feel free to open an Issue or create a Pull Request and let's get started 🚀