Latte: Cross-framework Python Package for Evaluation of Latent-based Generative Models
Latte
Latte (for LATent Tensor Evaluation) is a cross-framework Python package for evaluation of latent-based generative models. Latte supports calculation of disentanglement and controllability metrics in both PyTorch (via TorchMetrics) and TensorFlow.
Installation
For developers working on a local clone, cd to the repo and replace latte-metrics with a dot (.) in the commands below. For example, pip install .[tests]
pip install latte-metrics # core (numpy only)
pip install latte-metrics[pytorch] # with torchmetrics wrapper
pip install latte-metrics[keras] # with tensorflow wrapper
pip install latte-metrics[tests] # for testing
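To confirm which optional backends ended up in the environment, a quick check along these lines may help. This is a minimal sketch and not part of Latte: the helper name available_backends is hypothetical, and it assumes the pytorch and keras extras install the importable modules torchmetrics and tensorflow, as the comments on the commands above suggest.

```python
from importlib.util import find_spec


def available_backends():
    """Report which optional wrapper dependencies can be imported.

    Hypothetical helper, not part of Latte. The module names are assumptions
    based on the extras listed in the install commands.
    """
    modules = {"core": "numpy", "pytorch": "torchmetrics", "keras": "tensorflow"}
    # find_spec returns None when a top-level module cannot be located.
    return {extra: find_spec(name) is not None for extra, name in modules.items()}


if __name__ == "__main__":
    for extra, found in available_backends().items():
        print(f"{extra}: {'installed' if found else 'missing'}")
```

A missing entry only means the corresponding wrapper cannot be used; the core functional API is described as needing numpy only.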
Running tests locally
pip install .[tests]
pytest tests/ --cov=latte
Example
Functional API
import latte
from latte.functional.disentanglement.mutual_info import mig
import numpy as np
latte.seed(42)
z = np.random.randn(16, 8)
a = np.random.randn(16, 2)
mutual_info_gap = mig(z, a, discrete=False, reg_dim=[4, 3])
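For readers unfamiliar with the quantity being returned, the sketch below illustrates the general idea of a mutual information gap for a single discrete attribute: the difference between the two largest mutual information values across latent dimensions, normalised by the attribute entropy. It is written from that general definition under the assumption of discrete values, and it is not Latte's implementation; in particular it does not model the discrete=False or reg_dim arguments used above. The function names are hypothetical.

```python
import math
from collections import Counter


def entropy(values):
    """Shannon entropy (in nats) of a sequence of discrete values."""
    n = len(values)
    return -sum((c / n) * math.log(c / n) for c in Counter(values).values())


def mutual_info(x, y):
    """Mutual information (in nats) between two discrete sequences."""
    n = len(x)
    px, py, pxy = Counter(x), Counter(y), Counter(zip(x, y))
    return sum(
        (c / n) * math.log((c / n) / ((px[xi] / n) * (py[yi] / n)))
        for (xi, yi), c in pxy.items()
    )


def mig_discrete(z_columns, a):
    """Gap between the two largest I(z_j; a), normalised by H(a).

    Assumes at least two latent dimensions and a non-constant attribute.
    """
    scores = sorted((mutual_info(z_j, a) for z_j in z_columns), reverse=True)
    return (scores[0] - scores[1]) / entropy(a)


a = [0, 1, 0, 1]
z_columns = [
    [0, 1, 0, 1],  # latent dimension that copies the attribute
    [0, 0, 0, 0],  # latent dimension that carries no information
]
print(mig_discrete(z_columns, a))
```

In this toy case one latent dimension fully determines the attribute and the other is constant, so the gap reaches its maximum of 1 (up to floating point).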
Modular API
import latte
from latte.metrics.core.disentanglement import MutualInformationGap
import numpy as np
latte.seed(42)
mig = MutualInformationGap()
# ...
# initialize data and model
# ...
for data, attributes in batches:  # batches yields (data, attributes) pairs
    recon, z = model(data)
    mig.update_state(z, attributes)
mig_val = mig.compute()
TorchMetrics API
import latte
from latte.metrics.torch.disentanglement import MutualInformationGap
import torch
latte.seed(42)
mig = MutualInformationGap()
# ...
# initialize data and model
# ...
for data, attributes in batches:  # batches yields (data, attributes) pairs
    recon, z = model(data)
    mig.update(z, attributes)
mig_val = mig.compute()
Keras Metric API
import latte
from latte.metrics.keras.disentanglement import MutualInformationGap
from tensorflow import keras as tfk
latte.seed(42)
mig = MutualInformationGap()
# ...
# initialize data and model
# ...
for data, attributes in batches:  # batches yields (data, attributes) pairs
    recon, z = model(data)
    mig.update_state(z, attributes)
mig_val = mig.result()
Documentation
https://latte.readthedocs.io/en/latest
Method Chart for Modular API
TorchMetrics: https://torchmetrics.readthedocs.io/en/latest/pages/implement.html
Keras Metric: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Metric
Torch/Keras wrapper will
- convert torch/tf types to numpy types (and move everything to CPU)
- call native class methods
- if there are return values, convert numpy types back to torch/tf types
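The three wrapper steps listed above can be sketched generically as follows. This is an illustration of the pattern only, not Latte's wrapper code: wrap_for_framework and native_mean_abs are hypothetical names, and nested Python lists stand in for framework tensors so the example runs with numpy alone. With PyTorch, the converters would typically be something like lambda t: t.detach().cpu().numpy() and torch.from_numpy.

```python
import numpy as np


def native_mean_abs(z):
    """Stand-in for a native, NumPy-only metric function (hypothetical)."""
    return np.mean(np.abs(z), axis=0)


def wrap_for_framework(native_fn, to_numpy, from_numpy):
    """Build a framework-facing function around a NumPy-only one."""

    def wrapped(*tensors):
        arrays = [to_numpy(t) for t in tensors]  # 1. framework types -> numpy
        out = native_fn(*arrays)                 # 2. call the native function
        return from_numpy(out)                   # 3. numpy -> framework types

    return wrapped


# Nested lists play the role of framework tensors in this sketch.
list_metric = wrap_for_framework(
    native_mean_abs,
    to_numpy=np.asarray,
    from_numpy=lambda arr: arr.tolist(),
)
print(list_metric([[1.0, -2.0], [-3.0, 4.0]]))  # [2.0, 3.0]
```

Keeping the metric logic in one numpy implementation means each framework only needs a thin conversion layer, at the cost of moving data to the CPU on every call.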
Aspect | Native | TorchMetrics | Keras Metric |
---|---|---|---|
base class | latte.metrics.LatteMetric | torchmetrics.Metric | tf.keras.metrics.Metric |
super class | object | torch.nn.Module | tf.keras.layers.Layer |
adding buffer | self.add_state | self.add_state | self.add_weight |
updating buffer | self.update_state | self.update | self.update_state |
resetting buffer | self.reset_state | self.reset | self.reset_state |
computing metric values | self.compute | self.compute | self.result |
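To make the lifecycle in the table concrete, here is a toy metric that follows the native method names (add_state, update_state, reset_state, compute). It is a minimal sketch under stated assumptions: RunningMean is a hypothetical class that does not subclass latte.metrics.LatteMetric, and how Latte actually stores its buffers is not shown here.

```python
class RunningMean:
    """Toy metric using the native method names from the table.

    Hypothetical class for illustration; it does not reproduce
    latte.metrics.LatteMetric.
    """

    def __init__(self):
        self._defaults = {}
        self.add_state("total", 0.0)
        self.add_state("count", 0)

    def add_state(self, name, default):
        # Register a buffer and remember its default for later resets.
        self._defaults[name] = default
        setattr(self, name, default)

    def update_state(self, values):
        # Accumulate one batch into the buffers.
        self.total += sum(values)
        self.count += len(values)

    def reset_state(self):
        # Restore every buffer to its registered default.
        for name, default in self._defaults.items():
            setattr(self, name, default)

    def compute(self):
        # Derive the metric value from the accumulated buffers.
        return self.total / self.count


metric = RunningMean()
for batch in ([1.0, 2.0], [3.0]):
    metric.update_state(batch)
print(metric.compute())  # 2.0
metric.reset_state()
```

The TorchMetrics and Keras columns follow the same four stages under different method names, which is why the usage examples call update and compute for TorchMetrics but update_state and result for Keras.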
Supported metrics
🧪 Beta support | ✔️ Stable | 🔨 In Progress | 🕣 In Queue | 👀 KIV |
Metric | Latte Functional | Latte Modular | TorchMetrics | Keras Metric |
---|---|---|---|---|
Disentanglement Metrics | | | | |
📝 Mutual Information Gap (MIG) | 🧪 | 🧪 | 🧪 | 🧪 |
📝 Dependency-blind Mutual Information Gap (DMIG) | 🧪 | 🧪 | 🧪 | 🧪 |
📝 Dependency-aware Mutual Information Gap (XMIG) | 🧪 | 🧪 | 🧪 | 🧪 |
📝 Dependency-aware Latent Information Gap (DLIG) | 🧪 | 🧪 | 🧪 | 🧪 |
📝 Separate Attribute Predictability (SAP) | 🧪 | 🧪 | 🧪 | 🧪 |
📝 Modularity | 🧪 | 🧪 | 🧪 | 🧪 |
📝 β-VAE Score | 👀 | 👀 | 👀 | 👀 |
📝 FactorVAE Score | 👀 | 👀 | 👀 | 👀 |
📝 DCI Score | 👀 | 👀 | 👀 | 👀 |
📝 Interventional Robustness Score (IRS) | 👀 | 👀 | 👀 | 👀 |
📝 Consistency | 👀 | 👀 | 👀 | 👀 |
📝 Restrictiveness | 👀 | 👀 | 👀 | 👀 |
Interpolatability Metrics | | | | |
📝 Smoothness | 🧪 | 🧪 | 🧪 | 🧪 |
📝 Monotonicity | 🧪 | 🧪 | 🧪 | 🧪 |
📝 Latent Density Ratio | 🕣 | 🕣 | 🕣 | 🕣 |
📝 Linearity | 👀 | 👀 | 👀 | 👀 |
Bundled metric modules
🧪 Experimental (subject to changes) | ✔️ Stable | 🔨 In Progress | 🕣 In Queue |
Metric Bundle | Latte Functional | Latte Modular | TorchMetrics | Keras Metric | Included |
---|---|---|---|---|---|
Dependency-aware Disentanglement | 🧪 | 🔨 | 🕣 | 🕣 | MIG, DMIG, XMIG, DLIG |
LIAD-based Interpolatability | 🧪 | 🔨 | 🕣 | 🕣 | Smoothness, Monotonicity |
Cite
For individual metrics, please cite the paper linked from the 📝 icon in front of each metric.
If you find our package useful, please cite us as
@software{
watcharasupat2021latte,
author = {Watcharasupat, Karn N. and Lee, Junyoung and Lerch, Alexander},
title = {{Latte: Cross-framework Python Package for Evaluation of Latent-based Generative Models}},
url = {https://github.com/karnwatcharasupat/latte},
version = {0.0.1-alpha1}
}