The Vendi Score: A Diversity Evaluation Metric for Machine Learning
This repository contains the implementation of the Vendi Score (VS), a metric for evaluating diversity in machine learning. The input to the metric is a collection of samples and a pairwise similarity function, and the output is a number that can be interpreted as the effective number of unique elements in the sample. Specifically, given a positive semi-definite matrix $K \in \mathbb{R}^{n \times n}$ of similarity scores, the score is defined as: $$\mathrm{VS}(K) = \exp\left(-\mathrm{tr}\left(\frac{K}{n} \log \frac{K}{n}\right)\right) = \exp\left(-\sum_{i=1}^n \lambda_i \log \lambda_i\right),$$ where the $\lambda_i$ are the eigenvalues of $K/n$ and we use the convention $0 \log 0 = 0$. That is, the Vendi Score is the exponential of the von Neumann entropy of $K/n$, or equivalently the exponential of the Shannon entropy of the eigenvalues, a quantity also known as the effective rank.
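The definition above can be sketched in a few lines of NumPy (an illustrative sketch of the formula, not the library's implementation; the helper name vendi_score_from_K is made up here):

```python
import numpy as np

def vendi_score_from_K(K):
    """Exponential of the von Neumann entropy of K/n (illustrative sketch)."""
    n = K.shape[0]
    # K is symmetric positive semi-definite, so eigvalsh applies.
    lam = np.linalg.eigvalsh(K / n)
    lam = lam[lam > 1e-12]  # convention: 0 * log(0) = 0
    return float(np.exp(-np.sum(lam * np.log(lam))))

# n identical items give an effective number of 1;
# n fully distinct items give an effective number of n.
print(round(vendi_score_from_K(np.ones((4, 4))), 6))  # 1.0
print(round(vendi_score_from_K(np.eye(4)), 6))        # 4.0
```

The two extremes match the intuition: an all-ones similarity matrix means every sample is a duplicate, while an identity matrix means every sample is unique.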
Installation
You can install vendi_score from PyPI with pip:
pip install vendi_score
or by cloning this repository:
git clone https://github.com/vertaix/Vendi-Score.git
cd Vendi-Score
pip install -e .
vendi_score includes some optional dependencies for computing predefined similarity scores between images, text, or molecules. You can install these extras with commands such as the following:
pip install vendi_score[images]
pip install vendi_score[text,molecules]
pip install vendi_score[all]
Usage
The input to vendi_score is a list of samples and a similarity function, k, that maps a pair of elements to a similarity score. k should be symmetric, with k(x, x) = 1:
import numpy as np
from vendi_score import vendi
samples = [0, 0, 10, 10, 20, 20]
k = lambda a, b: np.exp(-np.abs(a - b))
vendi.score(samples, k)
# 2.9999
If you already have precomputed a similarity matrix:
K = np.array([[1.0, 0.9, 0.0],
[0.9, 1.0, 0.0],
[0.0, 0.0, 1.0]])
vendi.score_K(K)
# 2.1573
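As a sanity check, the same value can be recovered with plain NumPy from the eigenvalues of K/n (an illustrative sketch, not the library code):

```python
import numpy as np

K = np.array([[1.0, 0.9, 0.0],
              [0.9, 1.0, 0.0],
              [0.0, 0.0, 1.0]])
# Eigenvalues of K/n: (1.9/3, 0.1/3, 1/3).
lam = np.linalg.eigvalsh(K / K.shape[0])
lam = lam[lam > 1e-12]  # drop numerically-zero eigenvalues
vs = float(np.exp(-np.sum(lam * np.log(lam))))
print(round(vs, 4))  # 2.1573
```

The first two samples are nearly duplicates of each other and the third is distinct, so the effective number of unique elements lands a bit above 2.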
If your similarity function is a dot product between normalized embeddings $X\in\mathbb{R}^{n\times d}$, and $d \leq n$, it is faster to compute the Vendi score using the covariance matrix, $\frac{1}{n} \sum_{i=1}^n x_i x_i^{\top}$:
vendi.score_dual(X)
If the rows of $X$ are not normalized, pass normalize=True to score_dual.
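The speed-up comes from a standard linear-algebra fact: the $n \times n$ Gram matrix $\frac{1}{n} X X^{\top}$ and the $d \times d$ covariance matrix $\frac{1}{n} X^{\top} X$ share the same nonzero eigenvalues, so when $d \leq n$ the eigendecomposition can be done on the smaller matrix. A NumPy sketch of this equivalence (illustrative only, not the score_dual implementation):

```python
import numpy as np

def vs_from_eigs(lam):
    """Vendi Score from eigenvalues, with the 0 log 0 = 0 convention."""
    lam = lam[lam > 1e-12]
    return float(np.exp(-np.sum(lam * np.log(lam))))

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 8))
X /= np.linalg.norm(X, axis=1, keepdims=True)  # normalize rows

n = X.shape[0]
vs_primal = vs_from_eigs(np.linalg.eigvalsh(X @ X.T / n))  # n x n Gram matrix
vs_dual = vs_from_eigs(np.linalg.eigvalsh(X.T @ X / n))    # d x d, cheaper
print(np.isclose(vs_primal, vs_dual))  # True
```

The Gram matrix has at most $d$ nonzero eigenvalues, which are exactly the eigenvalues of the covariance matrix, so both routes yield the same score.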
Similarity functions
Some similarity functions are provided in vendi_score.image_utils, vendi_score.text_utils, and vendi_score.molecule_utils. For example:
Images:
from torchvision import datasets
from vendi_score import image_utils
mnist = datasets.MNIST("data/mnist", train=False, download=True)
digits = [[x for x, y in mnist if y == c] for c in range(10)]
pixel_vs = [image_utils.pixel_vendi_score(imgs) for imgs in digits]
# The default embeddings are from the pool-2048 layer of the torchvision
# Inception v3 model.
inception_vs = [image_utils.embedding_vendi_score(imgs, device="cuda") for imgs in digits]
for y, (pvs, ivs) in enumerate(zip(pixel_vs, inception_vs)):
    print(f"{y}\t{pvs:.02f}\t{ivs:.02f}")
# Output:
# 0 7.68 3.45
# 1 5.31 3.50
# 2 12.18 3.62
# 3 9.97 2.97
# 4 11.10 3.75
# 5 13.51 3.16
# 6 9.06 3.63
# 7 9.58 4.07
# 8 9.69 3.74
# 9 8.56 3.43
Text:
from vendi_score import text_utils
sents = ["Look, Jane.",
"See Spot.",
"See Spot run.",
"Run, Spot, run.",
"Jane sees Spot run."]
ngram_vs = text_utils.ngram_vendi_score(sents, ns=[1, 2])
bert_vs = text_utils.embedding_vendi_score(sents, model_path="bert-base-uncased")
simcse_vs = text_utils.embedding_vendi_score(sents, model_path="princeton-nlp/unsup-simcse-bert-base-uncased")
print(f"N-grams: {ngram_vs:.02f}, BERT: {bert_vs:.02f}, SimCSE: {simcse_vs:.02f}")
# N-grams: 3.91, BERT: 1.21, SimCSE: 2.81
More examples can be found in the Jupyter notebooks in the examples/ folder.
Project details
File details
Details for the file vendi-score-0.0.3.tar.gz.
File metadata
- Download URL: vendi-score-0.0.3.tar.gz
- Upload date:
- Size: 13.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.8.12
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 7b133fd293d63038aea032b2933c68a7040991ee91d9b953fb6b1ede43526c53 |
| MD5 | e9576c25cbdda4d10170721401a3510a |
| BLAKE2b-256 | 7f4cffff6368e4f13a17b8b65df59c868f6a2c8c9feadaaa203e7b6ac5e5f659 |