Most Exciting Input

Generate Most Exciting Inputs (MEIs) to explore and understand a PyTorch model's behavior by identifying input samples that induce high activation in specific neurons of your model.

Paper

Input optimization for interpreting neural generative models

Installation

pip install meitorch

Usage

  1. Load the model you want to generate interpretable visualizations for.

model = load_your_model()

  2. Define the operation you want to optimize most exciting inputs for.

Your operation must take a batch of inputs as its parameter and return a dictionary of losses. Visualizations will be generated for the input that minimizes the loss named "objective" (so negate any quantity you want to maximize, as with the activation below).

import torch

def operation(inputs):
    outputs = model(inputs)
    activation = outputs[:, 0]
    activation = torch.mean(activation, dim=0)
    loss = -activation
    losses = dict(
        objective=loss,
        activation=activation,
    )
    return losses

You can also define more complex operations that include multiple losses. Adding other losses to the dictionary will enable you to plot them after the optimization.

def operation(inputs):
    outputs = model(inputs)
    activation = outputs[:, 0]
    activation = torch.mean(activation, dim=0)
    model_losses = compute_loss(inputs, outputs)  # e.g. returns {"elbo": ...}
    regularization = model_losses["elbo"] * 0.1
    loss = -activation + regularization
    losses = dict(
        objective=loss,
        activation=activation,
        elbo_regularization=regularization,
    )
    return losses
  3. Create a MEI object with your operation and the input shape of your model.

from meitorch.mei import MEI

mei = MEI(operation=operation, shape=(1, 40, 40), device=device)
  4. Define a configuration for the optimization and generate Most Exciting Inputs. The configuration differs slightly between optimization schemes; see Configurations below.

Generate pixel-wise MEI

pixel_mei_config = dict()  # your config here
result = mei.generate_pixel_mei(config=pixel_mei_config)

Generate variational MEI

variational_mei_config = dict()  # your config here
result = mei.generate_variational_mei(config=variational_mei_config)

Generate transformation MEI

transformation_mei_config = dict()  # your config here
result = mei.generate_transformation_mei(config=transformation_mei_config)
  5. Analyze the results.

Access the generated images and the losses from the result object.

Plot the loss curves and the visualizations

result.plot_losses(show=False, save_path=None, ranges=None)
result.plot_image_and_losses(save_path=None, ranges=None)

Plot spatial frequency spectrum of the generated images

result.plot_spatial_frequency_spectrum()
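
As a rough illustration of what this plot is based on (a sketch, not the library's plotting code), a spatial frequency spectrum is the magnitude of the image's 2-D Fourier transform, shifted so low frequencies sit in the center:

```python
import torch

# Illustrative only: magnitude of the 2-D FFT, with fftshift so the
# DC component (zero frequency) ends up at the center of the array.
def spatial_frequency_spectrum(image):
    spectrum = torch.fft.fftshift(torch.fft.fft2(image))
    return spectrum.abs()
```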

Further analysis

You can further analyze the results with the meitorch.analyze module.

from meitorch.analyze import Analyze

Configurations

For all configurations, you can use a schedule instead of a constant value for any parameter. A schedule is a function that takes the current iteration as input and returns the value for that iteration. You can access the schedule class in meitorch.tools.schedules.

from meitorch.tools.schedules import LinearSchedule

schedule = LinearSchedule(start=0.1, end=0.01)

Available schedules are:

- LinearSchedule(start, end)
- OctaveSchedule(values)
- RandomSchedule(minimum, maximum)
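
Since a schedule is just a callable from the iteration number to a value, you can also roll your own. A minimal sketch (the `exponential_schedule` helper below is hypothetical, not part of meitorch):

```python
# Hypothetical custom schedule: exponential decay from `start` toward `end`,
# assuming any callable taking the current iteration works as a schedule.
def exponential_schedule(start, end, decay=0.9):
    def schedule(iteration):
        value = start * (decay ** iteration)
        return max(value, end)  # never fall below `end`
    return schedule

lr_schedule = exponential_schedule(start=0.1, end=0.01)
```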

Pixel-wise MEI configuration example

image_mei_config = dict(
    iter_n=2,         # number of optimization steps
    n_samples=1,      # number of samples per batch
    save_every=1,     # save copy of image every n iterations
    bias=0,           # bias of the distribution the image is sampled from
    scale=1,          # scaling of the distribution the image is sampled from
    diverse=False,    # whether to use diverse sampling
    diverse_params=dict(
        div_metric='euclidean', # distance metric for diversity (euclidean, cosine, correlation)
        div_linkage='minimum',  # linkage criterion for diversity (minimum, average)
        div_weight=1.1,         # weight of diversity loss
    ),

    #pre-step transformations
    scaler=1.01,          # scaling of the image before each step
    jitter=3,             # size of translational jittering before each step

    #normalization/clipping
    train_norm=1,        # norm adjustment during step
    norm=1,              # norm adjustment after step

    #optimizer
    optimizer="rmsprop",    # optimizer (sgd, mei, rmsprop, adam)
    optimizer_params=dict(
        lr=0.03,            # learning rate
        weight_decay=1e-6,  # weight decay
    ),

    #preconditioning in the gradient
    precond=0.3,            # strength of gradient preconditioning filter falloff (float or schedule)

    #denoiser after each step
    blur='gaussian',        # denoiser type (gaussian, tv, bilateral)
    blur_params=dict(
        #gaussian
        kernel_size=3,
        sigma=LinearSchedule(0.1, 0.01),
        
        #tv
        #regularization_scaler=1e-7,
        #lr=0.0001,
        #num_iters=5,
        
        #bilateral
        #kernel_size=3,
        #sigma_color=LinearSchedule(1, 0.01),
        #sigma_spatial=LinearSchedule(0.25, 0.01),
    ),
)
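
The precond parameter damps high-frequency components of the gradient. A minimal sketch of how such frequency-domain preconditioning is commonly done (illustrative only; `precondition_gradient` is a hypothetical helper, not meitorch's actual implementation):

```python
import torch

# Illustrative sketch: scale each Fourier component of the gradient by
# 1 / f**falloff, damping high frequencies so the optimized image stays
# smooth.  A small offset avoids division by zero at the DC component.
def precondition_gradient(grad, falloff=0.3):
    h, w = grad.shape[-2:]
    fy = torch.fft.fftfreq(h).reshape(-1, 1)
    fx = torch.fft.fftfreq(w).reshape(1, -1)
    freqs = torch.sqrt(fx ** 2 + fy ** 2)
    scale = 1.0 / (freqs + 1.0 / max(h, w)) ** falloff
    grad_f = torch.fft.fft2(grad)
    return torch.fft.ifft2(grad_f * scale).real
```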

Variational MEI configuration example

var_mei_config = dict(
    iter_n=1,              # number of optimization steps
    save_every=100,        # save image every n iterations
    bias=0,                # bias of the distribution the image is sampled from
    scale=1,              # scaling of the distribution the image is sampled from

    #transformations
    scaler=RandomSchedule(1, 1.025),  # scaling of the image (float or schedule)
    jitter=None,                      # size of translational jittering

    #optimizer
    optimizer="rmsprop",        # optimizer (sgd, mei, rmsprop, adam)
    optimizer_params=dict(   
        lr=0.04,                # learning rate
        weight_decay=1e-7,      # weight decay
    ),

    #preconditioning
    precond=0.4,            # strength of gradient preconditioning filter

    #variational
    distribution='normal',      # distribution of the MEI (normal, laplace)
    n_samples_per_batch=(128,), # number of samples per batch (tuple)
    fixed_stddev=0.4,           # fixed stddev of the distribution, None for learned stddev
)
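
The distribution and fixed_stddev fields control a distribution over inputs rather than a single image. As a rough illustration of the underlying idea (a sketch, not meitorch's internals): optimize the mean of a distribution over images and draw differentiable samples via the reparameterization trick.

```python
import torch

# Illustrative only: a learnable mean with a fixed stddev, mirroring the
# `fixed_stddev` config field.  Samples stay differentiable w.r.t. `mean`.
mean = torch.zeros(1, 40, 40, requires_grad=True)
fixed_stddev = 0.4

def sample_images(n_samples):
    eps = torch.randn(n_samples, 40, 40)  # noise, no gradient needed
    return mean + fixed_stddev * eps      # reparameterized, differentiable
```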

Transformation MEI configuration example

For the transformation MEI, you need to define a transformation operation that takes an image as input and returns a transformed image. Any backpropagatable operation can be used as a transformation. In the example below, we use a generative convolutional network, which is defined in meitorch.tools.transformations.

transformation_mei_config = dict(
        iter_n=150,          # number of optimization steps
        save_every=1,        # save image every n iterations
        bias=0,              # bias of the distribution the image is sampled from
        scale=1,             # scaling of the distribution the image is sampled from
        n_samples=128,       # number of samples per batch

        #transformations before each step
        scaler=None,            # scaling of the image (float or schedule)
        jitter=None,            # size of translational jittering

        #normalization
        train_norm=None,        # norm adjustment during step

        #optimizer
        optimizer="mei",        # optimizer (sgd, mei, rmsprop, adam)
        optimizer_params=dict(
            lr=0.02,            # learning rate
            weight_decay=1e-5,  # weight decay
        ),
    
        #preconditioning
        precond=0.4,            # strength of gradient preconditioning filter

        # transformation operation
        transform=GenerativeConvNet(
            hidden_sizes=[1], fixed_stddev=0.6, kernel_size=9,
            activation=torch.nn.ReLU(), activate_output=False, shape=(1, 40, 40),
        ),
    )
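
Because any backpropagatable callable qualifies as a transformation, a custom one can be as simple as a module mapping an image batch to an image batch. A hypothetical example (`AverageSmooth` is illustrative, not part of meitorch, and is assumed interchangeable with GenerativeConvNet as long as gradients flow through it):

```python
import torch

# Hypothetical custom transformation: a fixed 3x3 average filter that
# lightly smooths the image batch while remaining fully differentiable.
class AverageSmooth(torch.nn.Module):
    def __init__(self):
        super().__init__()
        kernel = torch.full((1, 1, 3, 3), 1.0 / 9.0)
        self.register_buffer("kernel", kernel)

    def forward(self, x):
        # padding=1 preserves the spatial dimensions
        return torch.nn.functional.conv2d(x, self.kernel, padding=1)
```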
