VAE disentanglement framework built with PyTorch Lightning.
🧶 Disent
A modular disentangled representation learning framework for pytorch
⚠️ API is not yet stable
Visit the docs for more info, or browse the releases.
Contributions are welcome!
Overview
Disent is a modular disentangled representation learning framework for auto-encoders, built upon pytorch-lightning. This framework consists of various composable components that can be used to build and benchmark disentanglement pipelines.
The name of the framework is derived from both disentanglement and scientific dissent.
Goals
Disent aims to meet the following criteria:
- Provide high quality, readable and easily comparable implementations of VAEs
- Use best practices, e.g. torch.distributions
- Be extremely flexible & configurable
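As an illustration of the torch.distributions goal, here is a minimal sketch of a VAE's KL term written with distribution objects instead of a hand-coded formula. It is a generic PyTorch example, not disent's own code; the batch size (4) and latent size (6) are arbitrary assumptions.

```python
import torch
from torch.distributions import Normal, kl_divergence

# Illustrative shapes only: batch of 4, latent size 6.
torch.manual_seed(0)
mu = torch.randn(4, 6)
logvar = torch.randn(4, 6)

# Diagonal Gaussian posterior q(z|x) = N(mu, sigma^2) and standard normal prior p(z).
posterior = Normal(loc=mu, scale=torch.exp(0.5 * logvar))
prior = Normal(loc=torch.zeros_like(mu), scale=torch.ones_like(mu))

# Reparameterised sample, differentiable with respect to mu and logvar.
z = posterior.rsample()

# KL per latent dim, summed over dims, averaged over the batch.
kl = kl_divergence(posterior, prior).sum(dim=-1).mean()

# The same quantity using the closed-form expression, for comparison.
kl_manual = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp())).sum(dim=-1).mean()
print(torch.allclose(kl, kl_manual, atol=1e-5))
```

Using distribution objects keeps the sampling and divergence code unchanged if the posterior or prior family is later swapped out.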
Citing Disent
Please use the following citation if you use Disent in your research:
@Misc{Michlo2021Disent,
author = {Nathan Juraj Michlo},
title = {Disent - A modular disentangled representation learning framework for pytorch},
howpublished = {Github},
year = {2021},
url = {https://github.com/nmichlo/disent}
}
Getting Started
WARNING: Disent is still under active development. Features and APIs are not stable and should be expected to change! Only a very limited set of tests currently exists; these will be expanded over time.
The easiest way to use disent is by running experiments/hydra_system.py and changing the root config in experiments/config/config.yaml. Configurations are managed with Hydra Config.
PyPI:
- Install with: pip install disent (this will most likely be outdated)
- Visit the docs!
Source:
- Clone with: git clone --branch dev https://github.com/nmichlo/disent.git
- Change your working directory to the root of the repo: cd disent
- Install the requirements for python 3.8 with: pip3 install -r requirements.txt
- Run the default experiment after configuring experiments/config/config.yaml by running: PYTHONPATH=. python3 experiments/run.py
Features
Disent includes implementations of modules, metrics and datasets from various papers. Modules marked with a "🧵" were introduced in disent as part of my MSc research.
Frameworks
- Unsupervised:
- Weakly Supervised:
  - Ada-GVAE: AdaVae(..., average_mode='gvae'), usually better than the Ada-ML-VAE
  - Ada-ML-VAE: AdaVae(..., average_mode='ml-vae')
- Supervised:
- Experimental:
  - 🧵 Ada-TVAE: Adaptive Triplet VAE
  - 🧵 DO-TVE: Data Overlap Triplet Variational Encoder
  - various others not worth mentioning
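The Ada-GVAE entry refers to the adaptive averaging idea from Weakly-Supervised Disentanglement Without Compromises. The sketch below is written from that paper's description and is not disent's AdaVae implementation: for a pair of observations, latent dimensions whose per-dimension KL falls below the midpoint of the smallest and largest KL are treated as shared and replaced with the averaged posterior. Only the 'gvae' style averaging (arithmetic mean of means and variances) is shown; the 'ml-vae' mode is omitted, and the function names are hypothetical.

```python
import torch

def kl_normal_per_dim(mu1, logvar1, mu2, logvar2):
    """Per-dimension KL(N(mu1, var1) || N(mu2, var2)) for diagonal Gaussians."""
    var1, var2 = logvar1.exp(), logvar2.exp()
    return 0.5 * (logvar2 - logvar1 + (var1 + (mu1 - mu2) ** 2) / var2 - 1)

def ada_gvae_average(mu1, logvar1, mu2, logvar2):
    """Hypothetical sketch of adaptive GVAE-style averaging for a pair of posteriors."""
    kl = kl_normal_per_dim(mu1, logvar1, mu2, logvar2)              # (B, Z)
    max_kl = kl.max(dim=-1, keepdim=True).values
    min_kl = kl.min(dim=-1, keepdim=True).values
    threshold = 0.5 * (max_kl + min_kl)                             # (B, 1)
    shared = kl < threshold                                         # dims assumed unchanged

    # GVAE-style average: mean of the means and mean of the variances.
    ave_mu = 0.5 * (mu1 + mu2)
    ave_logvar = torch.log(0.5 * (logvar1.exp() + logvar2.exp()))

    # Shared dims take the average; the remaining dims keep their own posterior.
    new1 = (torch.where(shared, ave_mu, mu1), torch.where(shared, ave_logvar, logvar1))
    new2 = (torch.where(shared, ave_mu, mu2), torch.where(shared, ave_logvar, logvar2))
    return new1, new2, shared

# Example: only latent dim 2 differs substantially between the pair.
mu1 = torch.tensor([[0.0, 0.0, 0.0, 0.0]])
mu2 = torch.tensor([[0.0, 0.1, 3.0, 0.0]])
logvar = torch.zeros(1, 4)
(m1, lv1), (m2, lv2), shared = ada_gvae_average(mu1, logvar, mu2, logvar)
print(shared.tolist())  # [[True, True, False, True]]
```

The midpoint threshold lets the number of shared factors vary per pair without supervision, which is what distinguishes the adaptive variant from a fixed-k one.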
Many popular disentanglement frameworks still need to be added; please submit an issue if you would like an additional framework.
Todo:
- FactorVAE
- GroupVAE
- MLVAE
Metrics
- Disentanglement:
- FactorVAE Score
- DCI
- MIG
- SAP
- Unsupervised Scores
- 🧵 Flatness Score
- Measures the maximum width (distance between the two furthest points) of factor traversal embeddings over their path length (sum of distances between consecutive points). This is a combined measure of linearity and ordering, weighted towards axis alignment if the l2 width over the l1 path length is used.
- 🧵 Flatness Components - Linearity, Monotonicity & Ordering
- Measure linearity of factor traversal embeddings using softmax-style metric over PCA variances computed over embeddings
- Measure axis-alignment of factor traversal embeddings using softmax-style metric over embedding variances
- Measure ordering of embeddings by checking anchor-positive and anchor-negative distances correspond to ground-truth factors
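The flatness idea can be sketched in a few lines of NumPy. The flatness function follows the width over path length description above. The linearity and axis-alignment helpers are assumptions: the exact "softmax-style" formulation used by disent is not given here, so they simply take the largest variance as a fraction of the total (PCA variances for linearity, per-dimension variances for axis alignment). All function names are hypothetical.

```python
import numpy as np

def flatness(embeddings, width_ord=2, path_ord=2):
    """Max pairwise width over path length for one factor traversal, shape (N, Z)."""
    diffs = embeddings[:, None, :] - embeddings[None, :, :]         # all pairs
    width = np.linalg.norm(diffs, ord=width_ord, axis=-1).max()      # furthest two points
    steps = embeddings[1:] - embeddings[:-1]                         # consecutive points
    path = np.linalg.norm(steps, ord=path_ord, axis=-1).sum()
    return float(width / path) if path > 0 else 0.0

def linearity(embeddings):
    # ASSUMPTION: fraction of variance along the top principal component.
    centered = embeddings - embeddings.mean(axis=0)
    pca_vars = np.linalg.svd(centered, compute_uv=False) ** 2
    total = pca_vars.sum()
    return float(pca_vars.max() / total) if total > 0 else 0.0

def axis_alignment(embeddings):
    # ASSUMPTION: fraction of variance in the single highest-variance latent dim.
    dim_vars = embeddings.var(axis=0)
    total = dim_vars.sum()
    return float(dim_vars.max() / total) if total > 0 else 0.0

line = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])    # ordered, axis aligned
diag = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])    # ordered, not axis aligned
zigzag = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])  # not linear

print(flatness(line), flatness(diag))        # 1.0 1.0
print(round(flatness(diag, 2, 1), 4))        # 0.7071 (l2 width over l1 path)
print(round(flatness(zigzag), 4))            # 0.7071
print(linearity(diag), axis_alignment(diag)) # 1.0 0.5
```

The diagonal example shows how switching the path norm to l1 penalises traversals that are straight but not aligned with a latent axis.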
Some popular metrics still need to be added; please submit an issue if you wish to add your own or have a request.
Datasets:
Various common datasets used in disentanglement research are implemented, as well as new synthetic datasets that are generated programmatically on the fly. These are convenient and lightweight, requiring no storage space.
- Ground Truth:
  - Cars3D
  - dSprites
  - MPI3D
  - SmallNORB
  - Shapes3D
- Ground Truth Non-Overlapping (Synthetic):
  - 🧵 XYBlocks: 3 blocks of decreasing size that move across a grid. Blocks can be one of three colors: R, G or B. If a smaller block overlaps a larger one and is the same color, the block is xor'd to black.
  - 🧵 XYSquares: 3 squares (R, G, B) that move across a non-overlapping grid. Observations have no channel-wise loss overlap.
  - 🧵 XYObject: A simplistic version of dSprites with a single square.
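To illustrate how a synthetic ground-truth dataset can be generated on the fly without storage, here is a hypothetical XYObject-style sketch: each index is decoded into ground-truth factors (x and y grid positions) and the observation is rendered on demand. The class name, image size, square size and grid spacing are assumptions and do not reflect disent's actual XYObjectData.

```python
import numpy as np

class OnTheFlyXYObject:
    """Hypothetical dataset: a single white square on a grid, rendered per request."""

    def __init__(self, img_size=64, square_size=8, grid_spacing=8):
        self.img_size = img_size
        self.square_size = square_size
        self.grid_spacing = grid_spacing
        n = (img_size - square_size) // grid_spacing + 1   # positions per axis
        self.factor_sizes = (n, n)                         # factors: (x, y)

    def __len__(self):
        return int(np.prod(self.factor_sizes))

    def index_to_factors(self, idx):
        # Map a flat dataset index to one value per ground-truth factor.
        return np.unravel_index(idx, self.factor_sizes)

    def __getitem__(self, idx):
        x_idx, y_idx = self.index_to_factors(idx)
        img = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
        x0, y0 = x_idx * self.grid_spacing, y_idx * self.grid_spacing
        img[y0:y0 + self.square_size, x0:x0 + self.square_size] = 255
        return img

data = OnTheFlyXYObject()
print(len(data), data.factor_sizes)  # 64 (8, 8)
obs = data[63]                       # factors (7, 7): square in the bottom-right corner
print(obs.shape)                     # (64, 64, 3)
```

Because every observation is a pure function of its factors, the full dataset costs no disk space and factor values for metrics are available exactly.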
Input Transforms + Input/Target Augmentations
- Input based transforms are supported.
- Input and Target CPU and GPU based augmentations are supported.
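A rough sketch of the distinction between an input transform and an input-only augmentation follows. Both functions are hypothetical stand-ins written for illustration: the first converts a uint8 HWC NumPy image into a float CHW tensor in [0, 1] (similar in spirit to ToStandardisedTensor, but not its actual implementation), and the second adds Gaussian noise to the input while leaving the reconstruction target clean.

```python
import numpy as np
import torch

def to_standardised_tensor(obs: np.ndarray) -> torch.Tensor:
    """Hypothetical stand-in: uint8 HWC image -> float32 CHW tensor in [0, 1]."""
    assert obs.dtype == np.uint8 and obs.ndim == 3, "expected a uint8 HWC image"
    return torch.from_numpy(obs).permute(2, 0, 1).float().div(255.0)

def augment_input(x: torch.Tensor, noise_std: float = 0.05, generator=None) -> torch.Tensor:
    """Hypothetical input-only augmentation: additive Gaussian noise, clamped to [0, 1].

    The noise is created on x's device, so this can run on the GPU after a batch
    has been moved there (the generator, if given, must live on the same device).
    """
    noise = torch.randn(x.shape, generator=generator, device=x.device) * noise_std
    return (x + noise).clamp(0.0, 1.0)

obs = np.full((64, 64, 3), 255, dtype=np.uint8)
target = to_standardised_tensor(obs)                                       # clean target
x = augment_input(target, generator=torch.Generator().manual_seed(0))      # noisy input
print(target.shape, target.dtype)  # torch.Size([3, 64, 64]) torch.float32
```

Keeping the target un-augmented turns the auto-encoder objective into a denoising one; augmenting both input and target would instead act as plain data augmentation.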
Schedules/Annealing:
Hyper-parameter annealing is supported through the use of schedules. The currently implemented schedules include:
- Linear Schedule
- Cyclic Schedule
- Cosine Wave Schedule
- Various other wrapper schedules
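The schedules above produce a ratio that scales the initial hyper-parameter value from the config on each step. The sketch below shows plausible shapes for linear, cyclic and cosine wave ratios as plain functions of the step. The cyclic shape follows the cyclical annealing idea of https://arxiv.org/abs/1903.10145 (ramp up over part of each period, then hold); the function names, parameters and the 0.5 ramp fraction are assumptions and may differ from disent's schedule classes.

```python
import math

def linear_schedule(step, start_step, end_step, r_start=0.0, r_end=1.0):
    """Ratio moves linearly from r_start to r_end between two steps, then holds."""
    if end_step <= start_step:
        raise ValueError("end_step must be greater than start_step")
    t = min(max((step - start_step) / (end_step - start_step), 0.0), 1.0)
    return r_start + t * (r_end - r_start)

def cyclic_schedule(step, period, ramp_fraction=0.5, r_low=0.0, r_high=1.0):
    """Cyclical annealing: ramp up for ramp_fraction of each period, then hold at r_high."""
    t = (step % period) / period
    ramp = min(t / ramp_fraction, 1.0)
    return r_low + ramp * (r_high - r_low)

def cosine_wave_schedule(step, period, r_low=0.0, r_high=1.0):
    """Smooth wave between r_low (start of period) and r_high (middle of period)."""
    t = (step % period) / period
    return r_low + 0.5 * (1 - math.cos(2 * math.pi * t)) * (r_high - r_low)

def scheduled_value(initial_value, ratio):
    """The initial config value is multiplied by the schedule's ratio."""
    return initial_value * ratio

beta = 0.004  # initial value, as in the BetaVae example
for step in (0, 256, 700, 1024):
    print(step, scheduled_value(beta, cyclic_schedule(step, period=1024)))
```

With a period of 1024, beta starts at 0, reaches its configured value halfway through each cycle, holds, and resets at the start of the next cycle.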
Why?
- Created as part of my Computer Science MSc, scheduled for completion in 2021.
- I needed custom high quality implementations of various VAEs.
- A pytorch version of disentanglement_lib.
- I didn't have time to wait for Weakly-Supervised Disentanglement Without Compromises to release their code as part of disentanglement_lib. (As of September 2020 it has been released, but has unresolved discrepancies.)
- disentanglement_lib still uses outdated TensorFlow 1.0, and the flow of data is unintuitive because of its use of Gin Config.
Architecture
disent
- disent/data: raw ground truth datasets
- disent/dataset: dataset wrappers & sampling strategies
- disent/framework: frameworks, including Auto-Encoders and VAEs
- disent/metrics: metrics for evaluating disentanglement using ground truth datasets
- disent/model: common encoder and decoder models used for VAE research
- disent/schedule: annealing schedules that can be registered to a framework
- disent/transform: transform operations for processing & augmenting input and target data from datasets
experiment
- experiment/run.py: entrypoint for running basic experiments with hydra config
- experiment/config: root folder for hydra config files
- experiment/util: various helper code, pytorch lightning callbacks & visualisation tools for experiments
Example Code
The following is a basic working example of disent that trains a BetaVAE with a cyclic beta schedule and evaluates the trained model with various metrics.
Basic Example
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader
from disent.data.groundtruth import XYObjectData
from disent.dataset.groundtruth import GroundTruthDataset
from disent.frameworks.vae.unsupervised import BetaVae
from disent.metrics import metric_dci, metric_mig
from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder
from disent.schedule import CyclicSchedule
from disent.transform import ToStandardisedTensor
# We use this internally to test this script.
# You can remove all references to this in your own code.
from disent.util import is_test_run
# create the dataset & dataloaders
# - ToStandardisedTensor transforms images from numpy arrays to tensors and performs checks
data = XYObjectData()
dataset = GroundTruthDataset(data, transform=ToStandardisedTensor())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)
# create the BetaVAE model
# - adjusting the beta, learning rate, and representation size.
module = BetaVae(
    make_optimizer_fn=lambda params: Adam(params, lr=5e-4),
    make_model_fn=lambda: AutoEncoder(
        # z_multiplier is needed to output mu & logvar when parameterising normal distribution
        encoder=EncoderConv64(x_shape=dataset.x_shape, z_size=6, z_multiplier=2),
        decoder=DecoderConv64(x_shape=dataset.x_shape, z_size=6),
    ),
    cfg=BetaVae.cfg(beta=0.004)
)
# cyclic schedule for target 'beta' in the config/cfg. The initial value from the
# config is saved and multiplied by the ratio from the schedule on each step.
# - based on: https://arxiv.org/abs/1903.10145
module.register_schedule('beta', CyclicSchedule(
    period=1024,  # repeat every: trainer.global_step % period
))
# train model
# - for 65536 batches/steps
trainer = pl.Trainer(logger=False, checkpoint_callback=False, max_steps=65536, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)
# compute disentanglement metrics
# - we cannot guarantee which device the representation is on
# - this will take a while to run
get_repr = lambda x: module.encode(x.to(module.device))
metrics = {
**metric_dci(dataset, get_repr, num_train=10 if is_test_run() else 1000, num_test=5 if is_test_run() else 500, show_progress=True),
**metric_mig(dataset, get_repr, num_train=20 if is_test_run() else 2000),
}
# evaluate
print('metrics:', metrics)
Visit the docs for more examples!