The implementation of the RISE algorithm for the Captum framework
Project description
Captum RISE API Design
Background
RISE (Randomized Input Sampling for Explanation) is a perturbation based approach to compute attribution. RISE uses a Monte-Carlo approximation algorithm to detect the sensitivity of an output value with respect to input features. RISE is a post-hoc, local and model-agnostic interpretability method that can be used any type of image-like classification network to explain its decisions. RISE generates a pixel importance map (saliency), indicating how salient is each pixel with respect to the model's predictions. It estimates the sensitivity of each input feature by sampling n_masks
random occlusion masks, and computing the network output for each of the correspondingly occluded input images. Each mask is assigned a score based on the output of the model. Afterwards, masks are averaged, using the score as weight.
To sample occlusion masks, RISE assumes a strong spatial structure in the feature space, so that features that are close to each other are more likely to be correlated. It does not assume access to the parameters of the model, features or gradients and therefore does not require accessing the internals of the network architecture. When RISE is applied to images, it allows visualizing the regions of the image that were important for the prediction of a certain class.
Requires:
- $I$: input of size $H \times W$ in the space of images $\mathcal{I} = \lbrace I | I : \Lambda \rightarrow \mathbb{R}^{d} \rbrace$ with $\Lambda = \lbrace 1,\dots,H \rbrace \times \lbrace 1,\dots,W \rbrace$.
- $M: \Lambda \rightarrow [0,1]$: $N$ random binary bilinear upsampled masks.
- $f: \mathcal{I} \rightarrow \mathbb{R}$: black-box model that for a given input produce a scalar confidence score (for example, the probability of a single class).
Mathematical formula:
$$ \begin{aligned} S_{I,f}(\lambda) \stackrel{MC}{\approx} \frac{1}{\mathbb{E}[M] \cdot N} \sum_{i=1}^{N} f(I \odot M_{i}) \cdot M_{i}(\lambda) \end{aligned} $$
$S_{I,f}$ is the importance map, explaining the decision of model $f$ on input image $I$, normalized to the expectation of $M$. The importance of a pixel $\lambda \in \Lambda$, is the score over all possible masks $M$ conditioned on the event that pixel $\lambda$ is observed, i.e., $M(\lambda) = 1$. Finally, $f(I \odot M_{i})$ is a random variable that represents the model's output for the target class by performing the inference over the masked input, where $\odot$ denotes the Hadamard product (element-wise multiplication).
Pseudocode:
# Algorithm 1: Mask Generation
# Require: input_shape = (HxW), initial_mask_shape = (hxw)
# Ensure: upsample_shape = (h+1)*CH x (w+1)*CW (where CHxCW = ceil(H/h) x ceil(W/w))
Repeat N times:
Create a binary mask_i of size initial_mask_shape
Convert Binary Mask to Smooth Mask:
Upsample mask_i via bilinear interpolation to size upsampled_shape
Crop upsampled mask_i randomly with a window of size input_shape
Store the resulting mask_i inside the masks array or list
return masks
# Algorithm 2: Importance Map Generation
Repeat N times:
masked_input i = mask input with mask_i
Compute the model prediction for the masked input (weight_i of mask_i)
Add to the heatmap the mask_i multiplied by its weight_i
Normalize the final heatmap by the number n_masks of masks generated
return heatmap
More details regarding the method can be found in the original paper and in the RISE implementation.
Design Considerations:
- RISE works for generic inputs shapes provided by the user.
- The return value is the importance map.
- This implementation uses the extendes the Feature Ablation capabilities of Captum with the same API, making the code simple to mantain and learn.
RISE
The RISE class is a generic class which allows training any interpretable model based on by sub-sampling the input image via random masks and recording its response to each of the masked images. The constructor takes only the model forward function. The RISE class makes certain assumptions to the generic FeatureAblation class in order to match the structure of other attribution methods and allow easier switching between methods.
Constructor:
RISE(forward_func: Callable)
Argument Descriptions:
forward_func
-torch.nn.Module
corresponding to the forward function of the model or any modification of it, for which attributions are desired. This is consistent with all other attribution constructors in Captum.
attribute:
attribute(input_set: TensorOrTupleOfTensorsGeneric,
n_masks: int,
initial_mask_shapes: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: Any = None,
show_progress: bool = False,
) -> TensorOrTupleOfTensorsGeneric:
Argument Descriptions:
These arguments follow standard definitions of existing Captum methods. inputs
is the input for which RISE attributions are computed. The type of it is Tensor or tuple[Tensor, ...]
because if forward_func
takes a single tensor as input, a single input tensor should be provided. And also, if forward_func
takes multiple tensors as input, a tuple of the input tensors should be provided. It is assumed that for all given input tensors, dimension 0 corresponds to the number of examples (aka batch size), and if multiple input tensors are provided, the examples must be aligned appropriately. initial_mask_shapes
is the initial mask shapes for each input of the model. The type of it is Tensor or tuple[Tensor, ...]
. target
is the target class for each input of the model.
A custom class named MaskSetConfig
is in charge of computing the generation of the set of masks for the preconfigured input shapes, types and initial mask shapes. By default, assumes each input is of shape [B,C,D1,D2,D3]
, where B
and C
are the batch and channel dimensions which are ignored, and the final mask sizes are extracted as [D1,D2,D3]
. Dimensions D2
and D3
are optional. Because of pytorch's interpolate
limitations, this only supports 5D inputs (3D masks) at most.
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
File details
Details for the file captum-rise-1.0.tar.gz
.
File metadata
- Download URL: captum-rise-1.0.tar.gz
- Upload date:
- Size: 8.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.4
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | 2c61754747090eea1c15edda4238b09d984ca5ab199508d0fa45330b032d03ff |
|
MD5 | effe879d5332deb52286c439472590b3 |
|
BLAKE2b-256 | a92046e7deafddebd9b763cc5e45f43daa952f77a3bc4b0a0d7794b5f0884fff |