A package for computing compositional explanations with either an optimal search or a beam search.

Compositional Explanations Package

This package provides functions to compute compositional explanations for a given bitmap and a set of boolean masks. The package includes two main methods for computing explanations: optimal and beam search. Additionally, it provides metrics to evaluate the quality of the explanations.

The package assumes you are able to provide the following inputs:

  • bitmaps: A boolean tensor representing the bitmap to be explained for a given neuron. The bitmap represents the activation of the neuron for a set of inputs. A cell is 1 if the neuron is activated for that feature or whole input, and 0 otherwise.
  • masks: A list of boolean tensors (or csr sparse matrices from the scipy library) representing the masks for each feature. Each mask represents whether a given concept/feature is annotated in the given position or samples. A cell is 1 if the feature is present in that position or sample, and 0 otherwise.
  • disjoint_info: A boolean tensor representing the disjointness of the masks. The tensor is of shape (num_masks, num_masks) and indicates whether two masks are disjoint (i.e., they do not overlap). A cell is 1 if the two masks are disjoint, and 0 otherwise.
  • length: An integer representing the maximum length of the explanation formula. The formula is a combination of masks that best explains the bitmap.
  • beam_size: An integer representing the beam size for the beam search method. This parameter controls the number of candidate explanations to consider at each step of the search.
  • device: A string (or, as in the example below, a torch.device) identifying the device to run the computations on. Using a GPU is strongly recommended for faster computation.
  • cache_dir: A string representing the directory to cache heuristic information for the given masks. This is useful for large datasets where computing the heuristics can be time-consuming. Default is None, which means no caching will be used.
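
As a rough illustration of the disjoint_info input described above, the following sketch derives the (num_masks, num_masks) matrix from dense boolean masks with a single matrix product instead of a Python loop. The helper name disjoint_matrix_from_masks is hypothetical and not part of the package. Keeping the diagonal at True is an assumption that mirrors the usage example further down, which initialises the matrix with ones.

```python
import torch


def disjoint_matrix_from_masks(masks):
    """Hypothetical helper: pairwise disjointness of dense boolean masks."""
    # Flatten every mask and stack them into a (num_masks, num_cells) matrix.
    flat = torch.stack([m.reshape(-1) for m in masks]).to(torch.float32)
    # Entry (i, j) counts the cells where mask i and mask j are both 1.
    overlap = flat @ flat.T
    # Two masks are disjoint when they share no active cell.
    disjoint = overlap == 0
    # Assumption: keep the diagonal at True, as the loop-based example does.
    eye = torch.eye(len(masks), dtype=torch.bool, device=disjoint.device)
    return disjoint | eye


toy_masks = [
    torch.tensor([[1, 0], [0, 0]], dtype=torch.bool),
    torch.tensor([[0, 1], [0, 0]], dtype=torch.bool),
    torch.tensor([[1, 1], [0, 0]], dtype=torch.bool),
]
toy_disjoint = disjoint_matrix_from_masks(toy_masks)
print(toy_disjoint)
```

Here the first two masks do not overlap, while the third overlaps both of them. For sparse annotations, computing disjointness directly from the annotation source is likely to be cheaper than this dense product.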

The functions provided in this package are:

  • optimal.compute_optimal_explanations(bitmaps, masks, disjoint_info, length, device, cache_dir): Computes the optimal explanation formula for the given bitmap and masks. This method can be slow (about 1 hour per neuron) in high-complexity scenarios.
  • beam.compute_beam_explanations(bitmaps, masks, disjoint_info, length, beam_size, device, cache_dir): Computes the explanation formula for the given bitmap and masks using a beam search method. This method is faster than the optimal method but cannot guarantee the optimality of the explanation. It is recommended to use this method for large datasets or when a faster computation is needed.
  • metrics.compute_iou_from_masks(formula, masks, bitmaps): Computes the Intersection over Union (IoU) metric for the given explanation formula, masks, and bitmap. The IoU metric measures the overlap between the explanation masks and the bitmap, providing a measure of how well the explanation captures the bitmap's activation.

Note that all of these functions take keyword (named) arguments.
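
To make the IoU metric concrete, here is a minimal sketch that computes it by hand for one small formula. It does not call the package: masks and the bitmap are represented as Python sets of active (flattened) positions, and the formula (A OR B) AND NOT C is hard-coded, because the formula format returned by the package is not described here. All names in the block are illustrative.

```python
def iou(explanation, bitmap):
    """Intersection over Union between two sets of active positions."""
    union = explanation | bitmap
    if not union:
        return 0.0
    return len(explanation & bitmap) / len(union)


# Flattened positions where each concept is annotated.
mask_a = {0, 1, 2}
mask_b = {2, 3}
mask_c = {3, 4}
# Flattened positions where the neuron is active.
bitmap = {0, 1, 2, 5}

# Evaluate (A OR B) AND NOT C with set operations.
explanation = (mask_a | mask_b) - mask_c
score = iou(explanation, bitmap)
print(sorted(explanation), score)  # [0, 1, 2] 0.75
```

The explanation covers three of the four active positions and adds no false positives, so the intersection has 3 cells and the union has 4.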

Example usage of the compositional_explanations package

from compositional_explanations import beam, optimal, metrics
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create 5 random boolean masks of shape (4, 4)
masks = [torch.randint(0, 2, (4, 4), dtype=torch.bool) for _ in range(5)]

for concept_index, mask in enumerate(masks):
    print(f"Mask {concept_index}:")
    print(mask)

# Create a random bitmap of shape (4, 4)
bitmap = torch.randint(0, 2, (4, 4), dtype=torch.bool)

print("Bitmap:")
print(bitmap)

# Compute the disjoint matrix (inefficient; we suggest computing it directly from the annotations)
disjoint_matrix = torch.ones((len(masks), len(masks)), dtype=torch.bool)
for i in range(len(masks)):
    for j in range(i + 1, len(masks)):
        disjoint_matrix[i, j] = not torch.any(masks[i] & masks[j])
        disjoint_matrix[j, i] = disjoint_matrix[i, j]

print("Disjoint Matrix:")
print(disjoint_matrix)

# Compute optimal explanations
best_formula_optimal = optimal.compute_optimal_explanations(
    bitmaps=bitmap, masks=masks, disjoint_info=disjoint_matrix, length=3, device=device, cache_dir=None
)
optimal_iou = metrics.compute_iou_from_masks(formula=best_formula_optimal, masks=masks, bitmaps=bitmap)

print("Optimal Explanation:")
print("Best formula:", best_formula_optimal)
print("Best IoU:", optimal_iou)

# Compute beam search explanations
best_formula_beam = beam.compute_beam_explanations(
    bitmaps=bitmap, masks=masks, disjoint_info=disjoint_matrix, length=3, beam_size=5, device=device, cache_dir=None
)
beam_iou = metrics.compute_iou_from_masks(formula=best_formula_beam, masks=masks, bitmaps=bitmap)
print("Beam Explanation:")
print("Best formula:", best_formula_beam)
print("Best IoU:", beam_iou)
print()
print("Beam Explanation is optimal:", beam_iou == optimal_iou)
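
To illustrate what the length and beam_size parameters control, the following is a toy beam search over formulas, written in plain Python. It is a sketch under stated assumptions, not the package's implementation: masks are sets of active positions, the operators are limited to OR, AND, and AND NOT, length is taken to be the number of concepts in a formula, and no heuristics or disjoint_info pruning are used. The function beam_search and all other names are hypothetical.

```python
from itertools import product


def iou(a, b):
    """Intersection over Union between two sets of active positions."""
    union = a | b
    return len(a & b) / len(union) if union else 0.0


# Binary operators used to extend a formula by one concept.
OPS = {
    "OR": lambda x, y: x | y,
    "AND": lambda x, y: x & y,
    "AND NOT": lambda x, y: x - y,
}


def beam_search(bitmap, masks, length, beam_size):
    """Toy beam search; returns (formula string, IoU) of the best candidate seen."""
    # Formulas of length 1: every single concept, scored against the bitmap.
    beam = [(iou(cells, bitmap), name, cells) for name, cells in masks.items()]
    beam = sorted(beam, key=lambda c: c[0], reverse=True)[:beam_size]
    best = beam[0]
    for _ in range(length - 1):
        candidates = []
        # Extend every formula in the beam with every operator and concept.
        for (_, label, cells), (op, fn), (name, mask) in product(
            beam, OPS.items(), masks.items()
        ):
            new_cells = fn(cells, mask)
            candidates.append(
                (iou(new_cells, bitmap), f"({label} {op} {name})", new_cells)
            )
        # Keep only the beam_size highest-scoring candidates for the next step.
        beam = sorted(candidates, key=lambda c: c[0], reverse=True)[:beam_size]
        # On ties, the earlier (shorter) formula is kept.
        best = max(best, beam[0], key=lambda c: c[0])
    return best[1], best[0]


toy_masks = {"A": {0, 1}, "B": {2, 3}, "C": {3, 4}}
toy_bitmap = {0, 1, 2}
formula, score = beam_search(toy_bitmap, toy_masks, length=3, beam_size=2)
print(formula, score)  # ((A OR B) AND NOT C) 1.0
```

Because only beam_size candidates survive each step, a formula whose prefix scores poorly can be discarded before it is completed. This is why a beam search is faster than an exhaustive one but cannot guarantee optimality.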
