
Library that provides metrics to assess representation quality


Representation quality metrics for pretrained deep models! ⭐


Reptrix



About

Reptrix, short for Representation Metrics, is a PyTorch library designed to simplify the evaluation of representation quality in pretrained deep neural networks. Reptrix offers a suite of recently proposed metrics, predominantly from the vision self-supervised learning literature, for researchers and engineers working on the design, deployment, evaluation and interpretability of deep neural networks in computer vision settings.

Key Features:

  • Comprehensive Metric Suite: Includes a variety of metrics that assess different aspects of representation quality and are indicative of capacity, robustness and downstream task performance.
  • PyTorch Integration: Seamlessly integrates with existing PyTorch models and workflows, allowing for straightforward monitoring of learned representations with minimal setup.
  • Open Source: Open for contributions and enhancements from the community, including any new metrics that are proposed.

Reptrix is built for machine learning practitioners who want to quantitatively analyze learned representations and improve the interpretability of their deep learning models, especially models trained with self-supervised learning. To learn more about why these metrics matter in modern deep learning workflows, check out our blog post on Assessing Representation Quality in SSL.

List of metrics currently supported

  • $\alpha$-ReQ: This metric computes the eigenvalues of the covariance matrix of the representations and fits a power law, $\lambda_i \propto i^{-\alpha}$, to the sorted eigenspectrum. The exponent $\alpha$ measures how quickly the eigenspectrum decays. A lower $\alpha$ (slower decay, with variance spread over more dimensions) indicates that the representations are more discriminative.
  • RankMe: This metric computes the effective rank of the representations, i.e. the exponential of the entropy of their normalized singular values. A higher rank indicates representations of higher capacity.
  • LiDAR: This metric computes the effective rank of the linear discriminant analysis (LDA) matrix, treating the augmented views of each sample as one class. A higher rank indicates representations with a higher degree of separability among object manifolds.
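To make the first two definitions concrete, here is a minimal NumPy sketch of $\alpha$-ReQ and RankMe as described in their papers. It is not Reptrix's implementation: the helper names `eigenspectrum`, `alpha_sketch` and `rankme_sketch` are hypothetical, the power law is assumed to be fitted by least squares in log-log space over the whole eigenspectrum (Reptrix's `get_alpha` may use a different fitting range), and RankMe follows the paper's definition $p_k = \sigma_k / \lVert\sigma\rVert_1 + \epsilon$, $\mathrm{RankMe} = \exp(-\sum_k p_k \log p_k)$.

```python
import numpy as np


def eigenspectrum(features):
    """Eigenvalues of the feature covariance matrix, sorted in descending order.

    `features` is a (num_samples, dim) array.
    """
    centered = features - features.mean(axis=0, keepdims=True)
    cov = centered.T @ centered / (features.shape[0] - 1)
    # eigvalsh is meant for symmetric matrices and returns ascending eigenvalues
    eigvals = np.linalg.eigvalsh(cov)[::-1]
    # Clip tiny negative values caused by floating-point error
    return np.clip(eigvals, 0.0, None)


def alpha_sketch(features, eps=1e-12):
    """Fit lambda_i ~ i^(-alpha) by a straight line in log-log space."""
    eigvals = eigenspectrum(features)
    ranks = np.arange(1, eigvals.size + 1)
    slope, _intercept = np.polyfit(np.log(ranks), np.log(eigvals + eps), deg=1)
    return -slope


def rankme_sketch(features, eps=1e-7):
    """Effective rank: exp of the entropy of the normalized singular values."""
    singular_values = np.linalg.svd(features, compute_uv=False)
    p = singular_values / singular_values.sum() + eps
    return float(np.exp(-(p * np.log(p)).sum()))
```

For instance, on synthetic features whose per-dimension variances follow $i^{-1.5}$, `alpha_sketch` should return a value close to 1.5, while `rankme_sketch` on a rank-one matrix should be close to 1.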

Time to compute each metric, in seconds:

| Metric | ResNet50 | ViT |
|---|---|---|
| $\alpha$-ReQ | 2.400 | 0.137 |
| RankMe | 2.364 | 0.091 |
| LiDAR | 7.929 | 0.162 |
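The timings above are reported without the dataset size or hardware used, so they are best read as relative costs. To get comparable numbers for your own setup, a minimal sketch like the following can time any metric function using only `time.perf_counter`. `time_metric` is a hypothetical helper, and the SVD call is only a stand-in for a real metric such as `rankme.get_rankme` applied to your own representations.

```python
import time

import numpy as np


def time_metric(metric_fn, features, repeats=3):
    """Return the best wall-clock time (in seconds) over `repeats` calls."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        metric_fn(features)
        best = min(best, time.perf_counter() - start)
    return best


# Stand-in features and metric: swap in your own representations and a
# Reptrix metric function when measuring real workloads
features = np.random.default_rng(0).standard_normal((2000, 512))
elapsed = time_metric(lambda x: np.linalg.svd(x, compute_uv=False), features)
print(f"{elapsed:.3f} s")
```

Taking the minimum over several repeats reduces the influence of one-off effects such as caching or background load.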

Using Reptrix in your own workflow

  1. Load your favourite pretrained network.
```python
import torch

encoder = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
# Remove the final fully connected layer and flatten the pooled output,
# so that the model outputs one 2048-dimensional feature vector per image
encoder = torch.nn.Sequential(*list(encoder.children())[:-1], torch.nn.Flatten())
encoder.eval()
```
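  2. Extract features from the pretrained network.
```python
import torch
from tqdm import tqdm


def get_features(encoder_network, dataloader, transform=None, num_augmentations=10):
    """Collect the representations of every sample in `dataloader`.

    Without `transform`, returns a (num_samples, dim) tensor. With `transform`,
    each batch is augmented `num_augmentations` times and the result has shape
    (num_samples, num_augmentations, dim).
    """
    all_features = []

    # Loop over the dataset and collect the representations
    for inputs, _ in tqdm(dataloader):
        if transform is not None:
            # Stack along dim=1 so that all augmentations of one sample stay
            # contiguous after flattening: (batch * num_augmentations, C, H, W)
            inputs = torch.stack([transform(inputs) for _ in range(num_augmentations)], dim=1)
            inputs = inputs.flatten(0, 1)
        with torch.no_grad():
            features = encoder_network(inputs)
        if transform is not None:
            # Put the augmentations in an additional dimension
            features = features.reshape(-1, num_augmentations, features.shape[-1])
        all_features.append(features.cpu())

    # Concatenate the features of all batches
    return torch.cat(all_features, dim=0)


# `loader` (your DataLoader) and `transform_augs` (your augmentation
# pipeline) are assumed to be defined elsewhere
num_augmentations = 10
all_representations = get_features(encoder, loader)
all_representations_lidar = get_features(encoder, loader,
                                         transform=transform_augs,
                                         num_augmentations=num_augmentations)
num_samples = all_representations_lidar.shape[0]
```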
  3. Compute the representation metrics.
```python
from reptrix import alpha, rankme, lidar

metric_alpha = alpha.get_alpha(all_representations)
metric_rankme = rankme.get_rankme(all_representations)
metric_lidar = lidar.get_lidar(all_representations_lidar, num_samples,
                               num_augmentations,
                               del_sigma_augs=0.00001)
```
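LiDAR is the least obvious of the three metrics, so here is a minimal NumPy sketch of the computation as described in the LiDAR paper, operating on an array laid out like `all_representations_lidar`, i.e. (num_samples, num_augmentations, dim). Each sample defines one surrogate class made of its augmented views; the metric is the effective rank of $\Sigma_w^{-1/2} \Sigma_b \Sigma_w^{-1/2}$. `lidar_sketch` is hypothetical and is not Reptrix's `get_lidar`; we assume that its `delta` ridge plays the role of `del_sigma_augs`, but have not verified this against the library.

```python
import numpy as np


def lidar_sketch(features, delta=1e-5, eps=1e-7):
    """Effective rank of the LDA matrix built from augmentation "classes".

    `features` has shape (num_samples, num_augmentations, dim).
    """
    num_samples, num_augs, dim = features.shape
    class_means = features.mean(axis=1)       # (num_samples, dim)
    grand_mean = class_means.mean(axis=0)     # (dim,)

    # Between-class scatter of the per-sample means
    diff_b = class_means - grand_mean
    sigma_b = diff_b.T @ diff_b / num_samples

    # Within-class scatter of the views around their sample mean, plus a ridge
    diff_w = (features - class_means[:, None, :]).reshape(-1, dim)
    sigma_w = diff_w.T @ diff_w / (num_samples * num_augs) + delta * np.eye(dim)

    # Sigma_w^{-1/2} from the eigendecomposition of the symmetric matrix sigma_w
    w_vals, w_vecs = np.linalg.eigh(sigma_w)
    inv_sqrt = w_vecs @ np.diag(w_vals ** -0.5) @ w_vecs.T
    sigma_lda = inv_sqrt @ sigma_b @ inv_sqrt

    # Effective rank of the LDA matrix (same entropy formula as RankMe)
    eigvals = np.clip(np.linalg.eigvalsh(sigma_lda), 0.0, None)
    p = eigvals / eigvals.sum() + eps
    return float(np.exp(-(p * np.log(p)).sum()))
```

With the tensors from the workflow above, the sketch could be called as `lidar_sketch(all_representations_lidar.cpu().numpy())`. If the sample means differ along only two directions while the views are tightly clustered around them, the score should be close to 2.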

Installation

Using pypi

You can install the latest version of reptrix using:

pip install reptrix

Manual installation

You can also install the latest version directly from the GitHub repository with:

pip install git+https://github.com/BARL-SSL/reptrix

Setup Conda environment for examples

You can incorporate reptrix into your existing conda environment, or create a new environment with the packages needed for the examples by running the following from a clone of this repository:

conda env create -f conda_env.yaml
conda activate reptrix
pip install -e .

Example code for Reptrix

We provide a tutorial Jupyter (IPython) notebook that shows how to incorporate metrics from the Reptrix library into your own code.

Related papers and Citations

This library currently supports metrics proposed in three different papers:

  1. [$\alpha$-ReQ: Assessing Representation Quality in Self-Supervised Learning by measuring eigenspectrum decay (NeurIPS 2022)](https://proceedings.neurips.cc/paper_files/paper/2022/hash/70596d70542c51c8d9b4e423f4bf2736-Abstract-Conference.html)
  2. RankMe: Assessing the Downstream Performance of Pretrained Self-Supervised Representations by Their Rank (ICML 2023)
  3. LiDAR: Sensing Linear Probing Performance in Joint Embedding SSL Architectures (ICLR 2024)

Contact

For questions related to this code, please open an issue or email us: Arna Ghosh, Arnab K Mondal, Danielle Benesch, Kumar K Agrawal

Contributing

You can check out the contributor's guide.

This project uses pre-commit; please install it before making any changes:

pip install pre-commit
cd reptrix
pre-commit install

It is a good idea to update the hooks to the latest version:

pre-commit autoupdate


Download files


Source Distribution

reptrix-0.1.0.tar.gz (89.3 kB)

Uploaded Source

Built Distribution


reptrix-0.1.0-py3-none-any.whl (17.6 kB)

Uploaded Python 3

File details

Details for the file reptrix-0.1.0.tar.gz.

File metadata

  • Download URL: reptrix-0.1.0.tar.gz
  • Upload date:
  • Size: 89.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.3

File hashes

Hashes for reptrix-0.1.0.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | 53d1b5254157024e7a5550e3cfffc6fd9bc84ac19903c0bce4c8cbd6497fb776 |
| MD5 | 1727c2c3709226463c5e3894d307fc63 |
| BLAKE2b-256 | c9cd2e3986b034631a442e7b3833bdbcc8cae3f5303334f67e1b786a3eb826de |


File details

Details for the file reptrix-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: reptrix-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 17.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.3

File hashes

Hashes for reptrix-0.1.0-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 7f8f2dcd1ad1306df019cefb56eb02404704ca55e84ca3d2cb970341021efdf3 |
| MD5 | 4b25d273f2133a3a2c1251106532ada4 |
| BLAKE2b-256 | 5901308697ff403cf28984e4a889ecf03de9f677e72ebe700eaff983c410bfb1 |

