Most Exciting Input
Generate Most Exciting Inputs (MEIs) to explore and understand a PyTorch model's behavior by identifying input samples that induce high activation in specific neurons of the model.
Paper
paper: Input optimization for interpreting neural generative models
Installation
pip install meitorch
Usage
- Load the model you want to generate interpretable visualizations for.
model = load_your_model()
- Define the operation you want to optimize most exciting inputs for.
Your operation must take a batch of inputs as parameters and return a dictionary of losses. Visualizations will be generated for the input that maximizes the loss named "objective".
import torch

def operation(inputs):
    outputs = model(inputs)
    activation = outputs[:, 0]
    activation = torch.mean(activation, dim=0)
    loss = -activation
    losses = dict(
        objective=loss,
        activation=activation,
    )
    return losses
You can also define more complex operations that include multiple losses. Adding other losses to the dictionary will enable you to plot them after the optimization.
def operation(inputs):
    outputs = model(inputs)
    activation = outputs[:, 0]
    activation = torch.mean(activation, dim=0)
    # compute_loss is your own function returning the model's losses as a dictionary
    model_losses = compute_loss(inputs, outputs)
    regularization = model_losses["elbo"] * 0.1
    loss = -activation + regularization
    losses = dict(
        objective=loss,
        activation=activation,
        elbo_regularization=regularization,
    )
    return losses
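The examples above read the neuron from the model's output. If the neuron of interest sits in a hidden layer, one option is to capture it with a standard PyTorch forward hook. The sketch below is an illustration under assumptions: `toy_model`, `save_hidden`, and `hidden_operation` are hypothetical names, and the small fully connected network only stands in for your own model.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model standing in for your own network.
toy_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(40 * 40, 16),
    nn.ReLU(),
    nn.Linear(16, 4),
)

captured = {}

def save_hidden(module, args, output):
    # Forward hooks receive (module, inputs, output); keep the layer output.
    captured["hidden"] = output

# Watch the output of the ReLU layer (index 2 in the Sequential).
hook_handle = toy_model[2].register_forward_hook(save_hidden)

def hidden_operation(inputs, unit=3):
    toy_model(inputs)  # the forward pass fills captured["hidden"]
    activation = captured["hidden"][:, unit].mean(dim=0)
    return dict(objective=-activation, activation=activation)

losses = hidden_operation(torch.randn(8, 1, 40, 40))
```

Call `hook_handle.remove()` once the hook is no longer needed.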
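- Create an MEI object with your operation and the input shape of your model.
from meitorch.mei import MEI
# Pass the function itself, not the result of calling it; `device` is the device your model runs on.
mei = MEI(operation=operation, shape=(1, 40, 40), device=device)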
- Define a configuration for the optimization and generate Most Exciting Inputs. The configurations for the different optimization schemes differ slightly; see Configurations below.
Generate pixel-wise MEI
pixel_mei_config = dict()  # your config here
result = mei.generate_pixel_mei(config=pixel_mei_config)
Generate variational MEI
variational_mei_config = dict()  # your config here
result = mei.generate_variational_mei(config=variational_mei_config)
Generate transformation MEI
transformation_mei_config = dict()  # your config here
result = mei.generate_transformation_mei(config=transformation_mei_config)
- Analyze the results. Access the generated images and the losses from the result object.
Plot the loss curves and the visualizations
result.plot_losses(show=False, save_path=None, ranges=None)
result.plot_image_and_losses(save_path=None, ranges=None)
Plot spatial frequency spectrum of the generated images
result.plot_spatial_frequency_spectrum()
Further analysis
You can further analyze the results with the meitorch.analyze module.
from meitorch.analyze import Analyze
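For readers unfamiliar with the spatial frequency spectrum mentioned above, the following sketch computes a centered 2D power spectrum with `torch.fft`. It is not the library's plotting routine; `power_spectrum` is a hypothetical helper, and the synthetic grating is used only so that the location of the peak is known in advance.

```python
import math
import torch

def power_spectrum(image):
    """Centered 2D power spectrum of a single (H, W) image."""
    spectrum = torch.fft.fftshift(torch.fft.fft2(image))
    return spectrum.abs() ** 2

# Synthetic test image: a grating with 4 cycles across 40 pixels,
# varying along the width and constant along the height.
x = torch.arange(40, dtype=torch.float32)
grating = torch.sin(2 * math.pi * 4 * x / 40).repeat(40, 1)

power = power_spectrum(grating)
# After fftshift the zero frequency sits at index 20 on both axes.
peak_row, peak_col = divmod(power.argmax().item(), 40)
```

The peak lies on the central row, four bins away from the center column, which matches the four cycles of the grating. A generated image dominated by high spatial frequencies would instead concentrate its power far from the center.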
Configurations
For all configurations, you can use a schedule instead of a constant value for any parameter. A schedule is a function that takes the current iteration as input and returns the value for that iteration. The schedule classes are available in meitorch.tools.schedules.
from meitorch.tools.schedules import LinearSchedule
schedule = LinearSchedule(start=0.1, end=0.01)
Available schedules are:
- LinearSchedule(start, end)
- OctaveSchedule(values)
- RandomSchedule(minimum, maximum)
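To make the idea of a schedule concrete, here is a plain-Python sketch of two such iteration-to-value functions. These are not the classes from meitorch.tools.schedules: `linear_schedule` and `random_schedule` are hypothetical helpers, and the `total_steps` argument is an assumption, since the source does not say how LinearSchedule learns the number of iterations. The behavior of OctaveSchedule is not described in the source, so it is left out.

```python
import random

def linear_schedule(start, end, total_steps):
    """Return f(iteration) that moves linearly from start to end."""
    def value_at(iteration):
        if total_steps <= 1:
            return end
        # Clamp so iterations outside [0, total_steps - 1] stay in range.
        fraction = min(max(iteration / (total_steps - 1), 0.0), 1.0)
        return start + fraction * (end - start)
    return value_at

def random_schedule(minimum, maximum, seed=None):
    """Return f(iteration) that draws uniformly from [minimum, maximum]."""
    rng = random.Random(seed)
    def value_at(iteration):
        return rng.uniform(minimum, maximum)
    return value_at

sigma = linear_schedule(0.1, 0.01, total_steps=10)
scaler = random_schedule(1.0, 1.025, seed=0)
```

With these helpers, `sigma(0)` gives the start value, `sigma(9)` the end value, and `scaler(i)` a fresh random value inside the given range on every call.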
Pixel-wise MEI configuration example
pixel_mei_config = dict(
    iter_n=2,  # number of optimization steps
    n_samples=1,  # number of samples per batch
    save_every=1,  # save copy of image every n iterations
    bias=0,  # bias of the distribution the image is sampled from
    scale=1,  # scaling of the distribution the image is sampled from
    diverse=False,  # whether to use diverse sampling
    diverse_params=dict(
        div_metric='euclidean',  # distance metric for diversity (euclidean, cosine, correlation)
        div_linkage='minimum',  # linkage criterion for diversity (minimum, average)
        div_weight=1.1,  # weight of diversity loss
    ),
    # pre-step transformations
    scaler=1.01,  # scaling of the image before each step
    jitter=3,  # size of translational jittering before each step
    # normalization/clipping
    train_norm=1,  # norm adjustment during step
    norm=1,  # norm adjustment after step
    # optimizer
    optimizer="rmsprop",  # optimizer (sgd, mei, rmsprop, adam)
    optimizer_params=dict(
        lr=0.03,  # learning rate
        weight_decay=1e-6,  # weight decay
    ),
    # preconditioning in the gradient
    precond=0.3,  # strength of gradient preconditioning filter falloff (float or schedule)
    # denoiser after each step
    blur='gaussian',  # denoiser type (gaussian, tv, bilateral)
    blur_params=dict(
        # gaussian
        kernel_size=3,
        sigma=LinearSchedule(0.1, 0.01),
        # tv
        # regularization_scaler=1e-7,
        # lr=0.0001,
        # num_iters=5,
        # bilateral
        # kernel_size=3,
        # sigma_color=LinearSchedule(1, 0.01),
        # sigma_spatial=LinearSchedule(0.25, 0.01),
    ),
)
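To give a feel for what a pixel-wise optimization does with settings such as `optimizer`, `optimizer_params`, and `norm`, here is a minimal sketch of generic activation maximization in plain PyTorch. It is not meitorch's internal loop: `toy_model` is a hypothetical stand-in, jitter, scaling, preconditioning, and blurring are omitted, and interpreting `norm=1` as rescaling to unit L2 norm after each step is an assumption.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for your own model.
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(40 * 40, 4))

def operation(inputs):
    activation = toy_model(inputs)[:, 0].mean(dim=0)
    return dict(objective=-activation, activation=activation)

# Start from a random image with unit L2 norm and optimize its pixels.
image = torch.randn(1, 1, 40, 40)
image = (image / image.norm()).requires_grad_()
optimizer = torch.optim.RMSprop([image], lr=0.03, weight_decay=1e-6)
history = []

for step in range(20):
    losses = operation(image)
    optimizer.zero_grad()
    losses["objective"].backward()  # minimizing -activation raises the activation
    optimizer.step()

    # Assumed post-step normalization: rescale the image back to unit L2 norm.
    with torch.no_grad():
        image.div_(image.norm() + 1e-12)

    history.append(losses["activation"].item())
```

Here `history` records the activation measured at the start of each step, so it can be inspected or plotted much like the loss curves of a result object.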
Variational MEI configuration example
from meitorch.tools.schedules import RandomSchedule

variational_mei_config = dict(
    iter_n=1,  # number of optimization steps
    save_every=100,  # save image every n iterations
    bias=0,  # bias of the distribution the image is sampled from
    scale=1,  # scaling of the distribution the image is sampled from
    # transformations
    scaler=RandomSchedule(1, 1.025),  # scaling of the image (float or schedule)
    jitter=None,  # size of translational jittering
    # optimizer
    optimizer="rmsprop",  # optimizer (sgd, mei, rmsprop, adam)
    optimizer_params=dict(
        lr=0.04,  # learning rate
        weight_decay=1e-7,  # weight decay
    ),
    # preconditioning
    precond=0.4,  # strength of gradient preconditioning filter
    # variational
    distribution='normal',  # distribution of the MEI (normal, laplace)
    n_samples_per_batch=(128,),  # number of samples per batch (tuple)
    fixed_stddev=0.4,  # fixed stddev of the distribution, None for learned stddev
)
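The `distribution`, `n_samples_per_batch`, and `fixed_stddev` options suggest that the variational scheme optimizes a distribution over inputs rather than a single image. The sketch below illustrates that reading with the reparameterization trick: a learnable mean image, a fixed standard deviation, and 128 samples per step. It is an assumption-laden illustration in plain PyTorch, not meitorch's implementation; `toy_model` and `expected_activation` are hypothetical names.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for your own model.
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(40 * 40, 4))

mean = torch.zeros(1, 40, 40, requires_grad=True)  # learnable mean image
fixed_stddev = 0.4  # kept constant, as with fixed_stddev in the config
optimizer = torch.optim.RMSprop([mean], lr=0.04, weight_decay=1e-7)

def expected_activation(n_samples):
    # Reparameterized draw: samples = mean + stddev * noise, so the
    # gradient flows back to `mean` through the sampling step.
    noise = torch.randn(n_samples, 1, 40, 40)
    samples = mean + fixed_stddev * noise
    return toy_model(samples)[:, 0].mean(dim=0)

before = expected_activation(128).item()
for step in range(20):
    optimizer.zero_grad()
    (-expected_activation(128)).backward()
    optimizer.step()
after = expected_activation(128).item()
```

Comparing `before` and `after` shows how much the average activation over sampled inputs has grown. Learning the standard deviation as well would mean adding a second learnable tensor alongside `mean`.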
Transformation MEI configuration example
For the transformation MEI, you need to define a transformation operation that takes an image as input and returns a transformed image. Any backpropagatable operation can be used as a transformation. In the example below, we use a generative convolutional network, which is defined in meitorch.tools.transformations.
from meitorch.tools.transformations import GenerativeConvNet

transformation_mei_config = dict(
    iter_n=150,  # number of optimization steps
    save_every=1,  # save image every n iterations
    bias=0,  # bias of the distribution the image is sampled from
    scale=1,  # scaling of the distribution the image is sampled from
    n_samples=128,  # number of samples per batch
    # transformations before each step
    scaler=None,  # scaling of the image (float or schedule)
    jitter=None,  # size of translational jittering
    # normalization
    train_norm=None,  # norm adjustment during step
    # optimizer
    optimizer="mei",  # optimizer (sgd, mei, rmsprop, adam)
    optimizer_params=dict(
        lr=0.02,  # learning rate
        weight_decay=1e-5,  # weight decay
    ),
    # preconditioning
    precond=0.4,  # strength of gradient preconditioning filter
    # transformation operation
    transform=GenerativeConvNet(
        hidden_sizes=[1],
        fixed_stddev=0.6,
        kernel_size=9,
        activation=torch.nn.ReLU(),
        activate_output=False,
        shape=(1, 40, 40),
    ),
)
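As an illustration of what "any backpropagatable operation" can mean here, the sketch below optimizes the parameters of a differentiable transformation instead of the pixels themselves. It uses a single 9x9 convolution as the transformation and a fixed batch of random base images. All of this is an assumption for illustration: `toy_model`, `transform`, and `base_images` are hypothetical, and the sketch does not reproduce GenerativeConvNet or meitorch's optimization loop.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for your own model; its weights stay frozen.
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(40 * 40, 4))
for parameter in toy_model.parameters():
    parameter.requires_grad_(False)  # only the transformation is trained

# Hypothetical transformation: one 9x9 convolution that keeps the 40x40 size.
transform = nn.Conv2d(1, 1, kernel_size=9, padding=4)
optimizer = torch.optim.SGD(transform.parameters(), lr=0.02, weight_decay=1e-5)

base_images = torch.randn(128, 1, 40, 40)  # fixed batch of input images

def mean_activation():
    # Transform the images, then read the target neuron (unit 0).
    return toy_model(transform(base_images))[:, 0].mean(dim=0)

before = mean_activation().item()
for step in range(150):
    optimizer.zero_grad()
    (-mean_activation()).backward()
    optimizer.step()
after = mean_activation().item()
```

After training, `transform(base_images)` yields a batch of images that drive the target unit more strongly than the initial transformation did.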