Library that provides metrics to assess representation quality
Representation quality metrics for pretrained deep models! ⭐
Reptrix
About
Reptrix, short for Representation Metrics, is a PyTorch library designed to simplify computing representation quality metrics for pretrained deep neural networks. Reptrix offers a suite of recently proposed metrics, predominantly from the vision self-supervised learning literature, that are essential for researchers and engineers working on the design, deployment, evaluation and interpretability of deep neural networks in computer vision.
Key Features:
- Comprehensive Metric Suite: Includes a variety of metrics that assess different aspects of representation quality and are indicative of capacity, robustness and downstream task performance.
- PyTorch Integration: Seamlessly integrates with existing PyTorch models and workflows, allowing for straightforward monitoring of learned representations with minimal setup.
- Open Source: Open for contributions and enhancements from the community, including any new metrics that are proposed.
Reptrix is built for machine learning practitioners who want to quantitatively analyze learned representations and improve the interpretability of their deep learning models, especially models trained with self-supervised learning. To learn more about why these metrics are essential in modern DL workflows, check out our blog post on Assessing Representation Quality in SSL.
List of metrics currently supported
- $\alpha$-ReQ: This metric computes the eigenvalues of the covariance matrix of the representations and fits a power law to the sorted eigenspectrum, $\lambda_i \propto i^{-\alpha}$, where $\lambda_i$ is the $i$-th largest eigenvalue. The fitted exponent, $\alpha$, measures how heavy-tailed the eigenspectrum is, i.e. how quickly it decays. A lower $\alpha$ indicates that the representations are more discriminative.
- RankMe: This metric computes the effective rank of the representations, i.e. the exponential of the entropy of their normalized singular values. A higher rank indicates representations of higher capacity.
- LiDAR: This metric computes the effective rank of the linear discriminant analysis (LDA) matrix, where the augmented views of each sample are treated as one class. A higher rank indicates representations with a higher degree of separability among object manifolds.
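To make the first two definitions concrete, here is a minimal NumPy sketch of α-ReQ and RankMe. It is not Reptrix's implementation: the function names (`eigenspectrum`, `alpha_req_sketch`, `rankme_sketch`) are hypothetical, α is fit by ordinary least squares in log-log space over the whole spectrum (the library may restrict the fitting range), and RankMe is computed as the entropy-based effective rank of the singular values of the representation matrix.

```python
import numpy as np


def eigenspectrum(reps: np.ndarray) -> np.ndarray:
    """Eigenvalues of the feature covariance matrix, largest first."""
    centered = reps - reps.mean(axis=0, keepdims=True)
    cov = centered.T @ centered / (reps.shape[0] - 1)
    eigvals = np.linalg.eigvalsh(cov)[::-1]  # eigvalsh returns ascending order
    return np.clip(eigvals, 0.0, None)       # remove tiny negative round-off


def alpha_req_sketch(reps: np.ndarray, eps: float = 1e-12) -> float:
    """Least-squares fit of log(lambda_i) = c - alpha * log(i) over the full spectrum."""
    eigvals = eigenspectrum(reps)
    ranks = np.arange(1, eigvals.size + 1)
    slope, _intercept = np.polyfit(np.log(ranks), np.log(eigvals + eps), deg=1)
    return float(-slope)


def rankme_sketch(reps: np.ndarray, eps: float = 1e-7) -> float:
    """exp(entropy) of the L1-normalized singular values of the representation matrix."""
    sing_vals = np.linalg.svd(reps, compute_uv=False)
    p = sing_vals / np.sum(sing_vals) + eps
    return float(np.exp(-np.sum(p * np.log(p))))


# Toy data whose covariance eigenvalues follow i^(-1.5) exactly:
# rows +diag(s) and -diag(s) give zero mean and a diagonal covariance proportional to s**2
dim = 50
s = np.arange(1.0, dim + 1) ** (-1.5 / 2)
toy_reps = np.vstack([np.diag(s), -np.diag(s)])
alpha_hat = alpha_req_sketch(toy_reps)  # close to 1.5
rank_full = rankme_sketch(np.eye(8))    # close to 8: all singular values are equal
```

If there are fewer samples than feature dimensions, the trailing eigenvalues are zero and dominate the log-log fit, so in practice the fit is usually restricted to the leading part of the spectrum.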
ResNet50
| Metric | Time to compute (s) |
|---|---|
| $\alpha$-ReQ | 2.400 |
| RankMe | 2.364 |
| LiDAR | 7.929 |
ViT
| Metric | Time to compute (s) |
|---|---|
| $\alpha$-ReQ | 0.137 |
| RankMe | 0.091 |
| LiDAR | 0.162 |
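Compute times depend on the hardware and on the number and dimensionality of the representations. To measure comparable numbers on your own setup, a small helper like the hypothetical `time_metric` below can wrap any metric function; it reports the median wall-clock time of several runs using `time.perf_counter`. The workload here (`np.linalg.matrix_rank` on random data) is only a stand-in for a real metric call.

```python
import time
from statistics import median
from typing import Callable

import numpy as np


def time_metric(metric_fn: Callable[[np.ndarray], float],
                reps: np.ndarray, repeats: int = 5) -> float:
    """Median wall-clock time in seconds of metric_fn(reps) over `repeats` runs."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        metric_fn(reps)
        timings.append(time.perf_counter() - start)
    return median(timings)


# Hypothetical stand-in workload: 1000 random 256-dimensional representations
rng = np.random.default_rng(0)
reps = rng.standard_normal((1000, 256))
elapsed = time_metric(lambda r: float(np.linalg.matrix_rank(r)), reps)
print(f"matrix_rank on {reps.shape}: {elapsed:.4f} s")
```

If the metric runs on a GPU with PyTorch, call `torch.cuda.synchronize()` before reading the clock, since CUDA kernels are launched asynchronously.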
Using Reptrix in your own workflow
- Load your favourite pretrained network.
```python
import torch

encoder = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
# Drop the final fully connected layer and flatten the pooled output,
# so that the model outputs a 2048-dimensional feature vector per image
encoder = torch.nn.Sequential(*list(encoder.children())[:-1], torch.nn.Flatten())
encoder.eval()
```
- Extract features from the pretrained network.
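```python
import torch
from tqdm import tqdm


def get_features(encoder_network, dataloader, transform=None, num_augmentations=10):
    """Collect representations for every sample in `dataloader`.

    Without `transform`, returns a (num_samples, dim) tensor. With `transform`,
    each batch is augmented `num_augmentations` times and the result has shape
    (num_samples, num_augmentations, dim).
    """
    all_features = []
    # Loop over the dataset and collect the representations
    for inputs, _ in tqdm(dataloader):
        batch_size = inputs.shape[0]
        if transform:
            # Stack the augmented copies along the batch dimension (augmentation-major order)
            inputs = torch.cat([transform(inputs) for _ in range(num_augmentations)], dim=0)
        with torch.no_grad():
            features = encoder_network(inputs).flatten(start_dim=1)
        if transform:
            # Put the augmentations in an additional dimension: (batch_size, num_augmentations, dim).
            # A plain reshape(-1, num_augmentations, dim) would mix different samples.
            features = features.reshape(num_augmentations, batch_size, -1).transpose(0, 1)
        all_features.append(features)
    # Concatenate all the features
    return torch.cat(all_features, dim=0)
```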
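```python
# `loader` (a DataLoader yielding (images, labels) batches) and `transform_augs`
# (a batch-level augmentation pipeline) are assumed to be defined beforehand
# Representations for alpha-ReQ and RankMe: one feature vector per sample
all_representations = get_features(encoder, loader)
num_augmentations = 10
# Augmented representations for LiDAR, grouped per sample
all_representations_lidar = get_features(encoder, loader,
                                         transform=transform_augs,
                                         num_augmentations=num_augmentations)
num_samples = all_representations_lidar.shape[0]
```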
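The LiDAR input appears to be laid out as (num_samples, num_augmentations, dim). Because the augmented batch is built with `torch.cat(..., dim=0)`, its rows come out augmentation-major (row `a * batch_size + n` holds augmentation `a` of sample `n`), so regrouping has to respect that order. The NumPy sketch below uses toy sizes and hypothetical variable names to compare a transpose-based regrouping with a direct reshape.

```python
import numpy as np

num_samples, num_augmentations, dim = 4, 3, 2

# Rows ordered as torch.cat([aug(x) for _ in range(A)], dim=0) would produce them.
# Each toy feature vector simply records (sample index n, augmentation index a).
flat = np.array([[n, a] for a in range(num_augmentations) for n in range(num_samples)],
                dtype=float)

# Correct regrouping: split by augmentation first, then swap the first two axes
grouped = flat.reshape(num_augmentations, num_samples, dim).transpose(1, 0, 2)

# Direct reshape: each group of num_augmentations rows mixes different samples
naive = flat.reshape(-1, num_augmentations, dim)
```

In `grouped`, entry `[n, a]` is augmentation `a` of sample `n`; in `naive`, the first group contains samples 0, 1 and 2 rather than three views of sample 0.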
- Compute the representation metrics
```python
from reptrix import alpha, rankme, lidar

metric_alpha = alpha.get_alpha(all_representations)
metric_rankme = rankme.get_rankme(all_representations)
metric_lidar = lidar.get_lidar(all_representations_lidar, num_samples,
                               num_augmentations,
                               del_sigma_augs=0.00001)
```
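As an illustration of what LiDAR measures on the (num_samples, num_augmentations, dim) tensor, here is a minimal NumPy sketch following the LiDAR paper's definition as we understand it: each sample is its own class, its augmented views are the class members, and the score is the effective rank of $\Sigma_w^{-1/2} \Sigma_b \Sigma_w^{-1/2}$, where $\Sigma_b$ is the between-class and $\Sigma_w$ the within-class covariance. The names (`effective_rank`, `inv_sqrt_psd`, `lidar_sketch`) are hypothetical, and treating `delta` as the analogue of Reptrix's `del_sigma_augs` (a small ridge term added to $\Sigma_w$) is an assumption based on the parameter name.

```python
import numpy as np


def effective_rank(eigvals: np.ndarray, eps: float = 1e-7) -> float:
    """exp(entropy) of the L1-normalized, non-negative spectrum."""
    p = eigvals / np.sum(eigvals) + eps
    return float(np.exp(-np.sum(p * np.log(p))))


def inv_sqrt_psd(mat: np.ndarray) -> np.ndarray:
    """Inverse square root V diag(lambda^-1/2) V^T of a symmetric positive definite matrix."""
    eigvals, eigvecs = np.linalg.eigh(mat)
    return (eigvecs / np.sqrt(eigvals)) @ eigvecs.T


def lidar_sketch(reps: np.ndarray, delta: float = 1e-5) -> float:
    """LiDAR-style score for reps of shape (num_samples, num_augmentations, dim)."""
    num_samples, num_aug, dim = reps.shape
    class_means = reps.mean(axis=1)                              # (N, D): one class per sample
    centered_means = class_means - class_means.mean(axis=0)
    sigma_b = centered_means.T @ centered_means / num_samples    # between-class covariance
    within = (reps - class_means[:, None, :]).reshape(-1, dim)   # spread of augmented views
    sigma_w = within.T @ within / (num_samples * num_aug) + delta * np.eye(dim)
    w_inv_sqrt = inv_sqrt_psd(sigma_w)
    sigma_lidar = w_inv_sqrt @ sigma_b @ w_inv_sqrt
    eigvals = np.clip(np.linalg.eigvalsh(sigma_lidar), 0.0, None)
    return effective_rank(eigvals)


# Toy check: 6 classes whose means span 3 directions, plus isotropic augmentation offsets
eye = np.eye(6)
means = np.vstack([eye[:3], -eye[:3]])          # (6, 6)
offsets = 0.1 * np.vstack([eye, -eye])          # (12, 6), sums to zero per class
toy = means[:, None, :] + offsets[None, :, :]   # (6, 12, 6)
toy_lidar = lidar_sketch(toy)                   # close to 3
```

Because the class means span only three directions while the augmentation noise is isotropic, only three discriminant directions carry signal, so the score is close to 3.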
Installation
Using pypi
You can install the latest version of reptrix using:
```bash
pip install reptrix
```
Manual installation
Alternatively, you can install it directly from the GitHub repository with:
```bash
pip install git+https://github.com/BARL-SSL/reptrix
```
Setup Conda environment for examples
You can incorporate reptrix into your existing conda environment or create a new environment with the necessary packages. From the root of a cloned copy of the repository, run:
```bash
conda env create -f conda_env.yaml
conda activate reptrix
pip install -e .
```
Example code for Reptrix
We provide a tutorial iPython notebook that shows how you can incorporate metrics from the Reptrix library into your own code.
Related papers and Citations
This library currently supports metrics proposed in three different papers:
- [$\alpha$-ReQ: Assessing Representation Quality in Self-Supervised Learning by measuring eigenspectrum decay (NeurIPS 2022)](https://proceedings.neurips.cc/paper_files/paper/2022/hash/70596d70542c51c8d9b4e423f4bf2736-Abstract-Conference.html)
- RankMe: Assessing the Downstream Performance of Pretrained Self-Supervised Representations by Their Rank (ICML 2023)
- LiDAR: Sensing Linear Probing Performance in Joint Embedding SSL Architectures (ICLR 2024)
Contact
For questions related to this code, please raise an issue or email us: Arna Ghosh, Arnab K Mondal, Danielle Benesch, Kumar K Agrawal
Contributing
You can check out the contributor's guide.
This project uses pre-commit; install it before making any changes:
```bash
pip install pre-commit
cd reptrix
pre-commit install
```
It is a good idea to update the hooks to the latest version:
```bash
pre-commit autoupdate
```
Download files
File details
Details for the file reptrix-0.1.0.tar.gz.
File metadata
- Download URL: reptrix-0.1.0.tar.gz
- Upload date:
- Size: 89.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 53d1b5254157024e7a5550e3cfffc6fd9bc84ac19903c0bce4c8cbd6497fb776 |
| MD5 | 1727c2c3709226463c5e3894d307fc63 |
| BLAKE2b-256 | c9cd2e3986b034631a442e7b3833bdbcc8cae3f5303334f67e1b786a3eb826de |
File details
Details for the file reptrix-0.1.0-py3-none-any.whl.
File metadata
- Download URL: reptrix-0.1.0-py3-none-any.whl
- Upload date:
- Size: 17.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7f8f2dcd1ad1306df019cefb56eb02404704ca55e84ca3d2cb970341021efdf3 |
| MD5 | 4b25d273f2133a3a2c1251106532ada4 |
| BLAKE2b-256 | 5901308697ff403cf28984e4a889ecf03de9f677e72ebe700eaff983c410bfb1 |