Generative modeling and representation learning through reconstruction
A variety of autoencoder-structured models for generative modeling and/or representation learning in PyTorch. The models are designed mostly for usability, extensibility, and research rather than as production implementations. That said, go ahead and train some models and reconstruct some things!
Models
LinearAE
A fully-connected autoencoder with a linear/multi-layer perceptron encoder and decoder
Reducing the Dimensionality of Data with Neural Networks
import torch
from autoencodersplz.models import LinearAE
model = LinearAE(
    img_size = 224,
    in_chans = 3,
    hidden_layers = [64, 64],
    dropout_rate = 0,
    latent_dim = 16,
    beta = 0.1, # beta > 0 = variational
    max_temperature = 1000, # kld temperature annealing
    device = None,
)
img = torch.rand(1, 3, 224, 224)
loss, reconstructed_img = model(img)
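The beta and max_temperature arguments point to a beta-weighted KL term that is annealed in during training. The numpy sketch below is an assumption about how such a loss could be computed, not the library's implementation. It uses mean-squared reconstruction error, the closed-form KL between a diagonal Gaussian and a standard normal, and a hypothetical linear warm-up that ramps the KL weight from 0 to beta over max_temperature steps. The names kl_weight, gaussian_kl, and vae_loss are hypothetical.

```python
import numpy as np

def kl_weight(step, beta=0.1, max_temperature=1000):
    """Hypothetical linear warm-up: the KL weight grows from 0 to beta over max_temperature steps."""
    return beta * min(step / max_temperature, 1.0)

def gaussian_kl(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, I)), summed over latent dims and averaged over the batch."""
    kl = 0.5 * np.sum(mu**2 + np.exp(logvar) - 1.0 - logvar, axis=1)
    return kl.mean()

def vae_loss(x, x_hat, mu, logvar, step, beta=0.1, max_temperature=1000):
    recon = np.mean((x - x_hat) ** 2)  # reconstruction term (MSE)
    return recon + kl_weight(step, beta, max_temperature) * gaussian_kl(mu, logvar)

rng = np.random.default_rng(0)
x = rng.random((4, 3 * 224 * 224))            # flattened 224x224 RGB images
mu, logvar = rng.normal(size=(4, 16)), rng.normal(size=(4, 16))  # latent_dim = 16
print(vae_loss(x, x, mu, logvar, step=500))   # recon = 0, so only the 0.05-weighted KL remains
```

With beta = 0 the KL weight is always zero and the loss reduces to a plain (non-variational) reconstruction objective, which matches the "beta > 0 = variational" comment above.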
LinearResidualAE
A fully-connected autoencoder with a linear/multi-layer perceptron residual network encoder and decoder
Skip Connections Eliminate Singularities
import torch
from autoencodersplz.models import LinearResidualAE
model = LinearResidualAE(
    img_size = 224,
    in_chans = 3,
    hidden_dim = [64, 64],
    blocks = [2, 2],
    dropout_rate = 0.1,
    with_batch_norm = False,
    latent_dim = 16,
    beta = 0.1, # beta > 0 = variational
    max_temperature = 1000, # kld temperature annealing
    device = None,
)
img = torch.rand(1, 3, 224, 224)
loss, reconstructed_img = model(img)
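A residual block returns its input plus a learned correction, y = x + f(x). As a minimal numpy sketch (an assumption about the block layout, not the library's code), a two-layer residual MLP block with ReLU looks like this; residual_block and the weight shapes are hypothetical.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def residual_block(x, w1, b1, w2, b2):
    """Two-layer MLP residual block: y = x + (relu(x W1 + b1) W2 + b2)."""
    h = relu(x @ w1 + b1)       # hidden projection
    return x + (h @ w2 + b2)    # skip connection adds the input back

rng = np.random.default_rng(0)
dim, hidden = 64, 64            # mirrors hidden_dim = [64, 64]
x = rng.normal(size=(8, dim))
w1, b1 = rng.normal(size=(dim, hidden)) * 0.1, np.zeros(hidden)
w2, b2 = rng.normal(size=(hidden, dim)) * 0.1, np.zeros(dim)
y = residual_block(x, w1, b1, w2, b2)
print(y.shape)  # (8, 64)
```

With all weights at zero the block is exactly the identity map, so adding blocks never makes the identity harder to represent.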
ConvResidualAE
A convolutional autoencoder with a ResNet encoder and symmetric decoder
Deep Residual Learning for Image Recognition
import torch
from autoencodersplz.models import ConvResidualAE
model = ConvResidualAE(
    img_size = 224,
    in_chans = 3,
    channels = [64, 128, 256, 512],
    blocks = [2, 2, 2, 2],
    latent_dim = 16,
    beta = 0, # beta > 0 = variational
    max_temperature = 1000, # kld temperature annealing
    upsample_mode = 'nearest', # interpolation method
    device = None,
)
img = torch.rand(1, 3, 224, 224)
loss, reconstructed_img = model(img)
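The upsample_mode argument names the interpolation the decoder uses to increase spatial resolution. With 'nearest' and an integer scale factor, each pixel is simply copied into a scale x scale block. The numpy sketch below shows only that interpolation, not how the decoder is wired. For integer factors it should match torch.nn.functional.interpolate(..., mode='nearest'). upsample_nearest is a hypothetical name.

```python
import numpy as np

def upsample_nearest(x, scale=2):
    """Nearest-neighbour upsampling of a (batch, channels, height, width) array by an integer factor."""
    return x.repeat(scale, axis=2).repeat(scale, axis=3)

x = np.arange(4, dtype=float).reshape(1, 1, 2, 2)
print(upsample_nearest(x)[0, 0])
# [[0. 0. 1. 1.]
#  [0. 0. 1. 1.]
#  [2. 2. 3. 3.]
#  [2. 2. 3. 3.]]
```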
VQVAE
A vector-quantized variational autoencoder with a ResNet encoder and symmetric decoder
Neural Discrete Representation Learning
import torch
from autoencodersplz.models import VQVAE
model = VQVAE(
    img_size = 224,
    in_chans = 3,
    channels = [64, 128, 256, 512],
    blocks = [2, 2, 2, 2],
    codebook_size = 256,
    codebook_dim = 8,
    use_cosine_sim = True,
    kmeans_init = True,
    commitment_weight = 0.5,
    upsample_mode = 'nearest',
    vq_kwargs = {},
)
img = torch.rand(1, 3, 224, 224)
loss, reconstructed_img = model(img)
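Vector quantization replaces each encoder output vector with its nearest codebook entry. With use_cosine_sim = True, "nearest" means highest cosine similarity, which is the same as L2-normalizing the vectors and the codebook and then taking the largest dot product. The commitment term pulls encoder outputs toward their assigned codes and is scaled by commitment_weight. This numpy sketch covers only the lookup and the loss value. It leaves out k-means initialization, codebook updates, and the straight-through gradient. vector_quantize is a hypothetical name.

```python
import numpy as np

def l2_normalize(x, eps=1e-8):
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

def vector_quantize(z, codebook, commitment_weight=0.5):
    """Assign each row of z (n, d) to the codebook row (k, d) with the highest cosine similarity."""
    z_n, c_n = l2_normalize(z), l2_normalize(codebook)
    sims = z_n @ c_n.T                  # (n, k) cosine similarities
    codes = sims.argmax(axis=1)         # index of the closest code for each vector
    quantized = c_n[codes]              # quantized vectors (on the unit sphere)
    commit = commitment_weight * np.mean((z_n - quantized) ** 2)
    return quantized, codes, commit

rng = np.random.default_rng(0)
codebook = rng.normal(size=(256, 8))    # codebook_size = 256, codebook_dim = 8
z = rng.normal(size=(10, 8))
quantized, codes, commit = vector_quantize(z, codebook)
print(codes.shape, quantized.shape)     # (10,) (10, 8)
```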
FSQVAE
A finite-scalar quantized variational autoencoder with a ResNet encoder and symmetric decoder
Finite Scalar Quantization: VQ-VAE Made Simple
import torch
from autoencodersplz.models import FSQVAE
model = FSQVAE(
    img_size = 224,
    in_chans = 3,
    channels = [64, 128, 256, 512],
    blocks = [2, 2, 2, 2],
    levels = [8, 6, 5],
    upsample_mode = 'nearest',
)
img = torch.rand(1, 3, 224, 224)
loss, reconstructed_img = model(img)
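Finite scalar quantization drops the learned codebook. Each latent channel is squashed into a bounded range and rounded to one of a fixed number of levels, so levels = [8, 6, 5] gives an implicit codebook of 8 * 6 * 5 = 240 codes. The numpy sketch below is a simplified version of that idea: tanh bounding, rescaling to [0, L - 1], rounding, and a mixed-radix code index. It omits the paper's exact bounding for even level counts and the straight-through gradient. fsq_quantize and codes_to_indices are hypothetical names.

```python
import numpy as np

def fsq_quantize(z, levels):
    """Quantize channel i of z (last axis) to an integer in [0, levels[i] - 1]."""
    levels = np.asarray(levels)
    bounded = (np.tanh(z) + 1.0) / 2.0 * (levels - 1)   # squash to [0, L - 1] per channel
    return np.round(bounded).astype(int)

def codes_to_indices(codes, levels):
    """Mixed-radix encoding: a unique integer in [0, prod(levels)) for each code vector."""
    basis = np.concatenate(([1], np.cumprod(levels[:-1])))
    return (codes * basis).sum(axis=-1)

levels = [8, 6, 5]
rng = np.random.default_rng(0)
z = rng.normal(size=(4, len(levels)))
codes = fsq_quantize(z, levels)
print(int(np.prod(levels)))             # 240 implicit codes
print(codes_to_indices(codes, levels))  # one integer index per latent vector
```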
MAE
A masked autoencoder with a vision transformer encoder and decoder
Masked Autoencoders Are Scalable Vision Learners
import torch
import torch.nn as nn
from autoencodersplz.models import MAE
model = MAE(
    img_size = 224,
    patch_size = 16,
    in_chans = 3,
    mask_ratio = 0.5,
    embed_dim = 768,
    depth = 12,
    num_heads = 12,
    mlp_ratio = 4,
    pre_norm = False,
    decoder_embed_dim = 768,
    decoder_depth = 12,
    decoder_num_heads = 12,
    norm_layer = nn.LayerNorm,
    patch_norm_layer = nn.LayerNorm,
    post_norm_layer = nn.LayerNorm,
)
img = torch.rand(1, 3, 224, 224)
loss, reconstructed_img = model(img)
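With img_size = 224 and patch_size = 16, the image is split into (224 / 16)^2 = 196 patches, and mask_ratio = 0.5 hides half of them from the encoder. The MAE paper's reference code samples this mask by drawing one random number per patch, sorting, and keeping the lowest-scoring patches. The numpy sketch below shows that per-sample shuffle. random_masking is a hypothetical name, and MAE here may not implement masking exactly this way.

```python
import numpy as np

def random_masking(num_patches, mask_ratio, batch_size, rng):
    """Return kept patch indices and a binary mask (1 = masked) for each sample."""
    len_keep = int(num_patches * (1 - mask_ratio))
    noise = rng.random((batch_size, num_patches))    # one random score per patch
    ids_shuffle = np.argsort(noise, axis=1)          # ascending: lowest scores are kept
    ids_keep = ids_shuffle[:, :len_keep]
    mask = np.ones((batch_size, num_patches))
    np.put_along_axis(mask, ids_keep, 0.0, axis=1)   # unmask the kept patches
    return ids_keep, mask

num_patches = (224 // 16) ** 2                       # 196
ids_keep, mask = random_masking(num_patches, 0.5, batch_size=2, rng=np.random.default_rng(0))
print(num_patches, ids_keep.shape, mask.sum(axis=1)) # 196 (2, 98) [98. 98.]
```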
MAEMix
A masked autoencoder with an MLP-Mixer encoder and decoder
MLP-Mixer: An all-MLP Architecture for Vision
import torch
from autoencodersplz.models import MAEMix
model = MAEMix(
    img_size = 224,
    patch_size = 16,
    in_chans = 3,
    mask_ratio = 0.5,
    embed_dim = 768,
    depth = 12,
    mlp_ratio = 4,
    decoder_embed_dim = 768,
    decoder_depth = 12,
)
img = torch.rand(1, 3, 224, 224)
loss, reconstructed_img = model(img)
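An MLP-Mixer layer alternates two MLPs. A token-mixing MLP runs across patches (for each channel), and a channel-mixing MLP runs across channels (for each patch). Each is preceded by layer norm and wrapped in a residual connection. The numpy sketch below follows that layout from the MLP-Mixer paper with random weights. It illustrates the idea rather than MAEMix's actual code, uses ReLU instead of the paper's GELU for brevity, and uses a small channel count (64 instead of embed_dim = 768) to keep it fast. The parameter names are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def mlp(x, w1, w2):
    return np.maximum(x @ w1, 0.0) @ w2          # two-layer MLP (ReLU for brevity)

def mixer_layer(x, params):
    """x has shape (batch, num_patches, channels)."""
    # Token mixing: transpose so the MLP acts along the patch axis.
    y = layer_norm(x).transpose(0, 2, 1)          # (batch, channels, patches)
    x = x + mlp(y, params["tok_w1"], params["tok_w2"]).transpose(0, 2, 1)
    # Channel mixing: MLP along the channel axis for each patch.
    return x + mlp(layer_norm(x), params["ch_w1"], params["ch_w2"])

rng = np.random.default_rng(0)
batch, patches, channels, ratio = 2, 196, 64, 4  # ratio plays the role of mlp_ratio
params = {
    "tok_w1": rng.normal(size=(patches, patches * ratio)) * 0.02,
    "tok_w2": rng.normal(size=(patches * ratio, patches)) * 0.02,
    "ch_w1": rng.normal(size=(channels, channels * ratio)) * 0.02,
    "ch_w2": rng.normal(size=(channels * ratio, channels)) * 0.02,
}
x = rng.normal(size=(batch, patches, channels))
print(mixer_layer(x, params).shape)  # (2, 196, 64)
```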
IJEPA
An image-based joint-embedding predictive architecture (I-JEPA); thanks to Yiran for porting this implementation
Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture
import torch
from autoencodersplz.models import IJEPA
model = IJEPA(
    img_size = 224,
    patch_size = 16,
    in_chans = 3,
    embed_dim = 768,
    depth = 12,
    num_heads = 12,
    mlp_ratio = 4,
    embed_dim_predictor = 384,
    predictor_depth = 12,
    num_targets = 4,
    target_aspect_ratio = 0.75,
    target_scale = 0.2,
    context_aspect_ratio = 1.,
    context_scale = 0.9,
)
img = torch.rand(1, 3, 224, 224)
loss, reconstructed_img = model(img)
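I-JEPA predicts the representations of several target blocks from a larger context block. Both are drawn on the patch grid, which has 224 / 16 = 14 patches per side. The sketch below shows one plausible way to turn target_scale, target_aspect_ratio, context_scale, and context_aspect_ratio into block masks. It uses a height/width rule similar to the I-JEPA reference mask collator, height = round(sqrt(area * aspect)) and width = round(sqrt(area / aspect)), and removes target patches from the context. The helper names are hypothetical, and IJEPA here may sample scales and aspect ratios from ranges instead of using fixed values.

```python
import math
import numpy as np

def block_size(grid, scale, aspect_ratio):
    """Block height/width (in patches) covering roughly `scale` of a grid x grid patch map."""
    area = int(grid * grid * scale)
    h = min(int(round(math.sqrt(area * aspect_ratio))), grid)
    w = min(int(round(math.sqrt(area / aspect_ratio))), grid)
    return h, w

def sample_block(grid, h, w, rng):
    """Boolean (grid, grid) mask with one h x w block at a random position."""
    mask = np.zeros((grid, grid), dtype=bool)
    top, left = rng.integers(0, grid - h + 1), rng.integers(0, grid - w + 1)
    mask[top:top + h, left:left + w] = True
    return mask

rng = np.random.default_rng(0)
grid = 224 // 16                                               # 14 patches per side
th, tw = block_size(grid, scale=0.2, aspect_ratio=0.75)        # (5, 7)
targets = [sample_block(grid, th, tw, rng) for _ in range(4)]  # num_targets = 4
ch, cw = block_size(grid, scale=0.9, aspect_ratio=1.0)         # (13, 13)
context = sample_block(grid, ch, cw, rng)
for t in targets:
    context &= ~t                                              # context never sees target patches
print((th, tw), (ch, cw), int(context.sum()))
```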
Training
Basic
The Trainer class enables basic training on a single CPU or GPU for any model in the autoencodersplz library. If you provide a path to the output_dir argument, the Trainer class will also automatically save the autoencoder model, the backbone/encoder, the losses, and a visualization of the training process (.gif).
from autoencodersplz.trainers import Trainer
# autoencoder: any autoencodersplz model; train_dataloader / valid_dataloader: torch DataLoaders
trainer = Trainer(
    autoencoder,
    train = train_dataloader,
    valid = valid_dataloader,
    epochs = 128,
    learning_rate = 5e-4,
    betas = (0.9, 0.95),
    weight_decay = 0.05,
    patience = 10,
    scheduler = 'plateau',
    save_backbone = True,
    show_plots = False,
    output_dir = 'training_run/',
    device = None,
)
trainer.fit()
By default, Trainer uses an AdamW optimizer and either a CosineDecay ('cosine') or ReduceLROnPlateau ('plateau') scheduler. To use a different optimizer or scheduler, assign a new one to the trainer's .optimizer or .scheduler attribute (building the optimizer from trainer.model.parameters()) before calling trainer.fit().
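For reference, a cosine decay schedule lowers the learning rate from its initial value to a floor along half a cosine period. This is the same closed form documented for torch.optim.lr_scheduler.CosineAnnealingLR. The pure-Python sketch below computes that curve. The details of the Trainer's 'cosine' option (warm-up, floor value, per-step or per-epoch updates) are not documented here, so the min_lr floor and per-epoch stepping are assumptions, and cosine_decay is a hypothetical name.

```python
import math

def cosine_decay(epoch, total_epochs, base_lr=5e-4, min_lr=0.0):
    """Learning rate after `epoch` of `total_epochs` under half-period cosine decay."""
    progress = min(epoch / total_epochs, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

lrs = [cosine_decay(e, 128) for e in (0, 64, 128)]
print(lrs)  # starts at base_lr, is half of base_lr at the midpoint, reaches min_lr at the end
```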
Lightning
To make it easier to scale to multi-GPU/distributed training, all autoencodersplz models are configured for use with PyTorch Lightning. Each model is set up with a default optimizer and scheduler and can be passed directly to the PyTorch Lightning Trainer, as in the example below.
import lightning.pytorch as pl
from autoencodersplz.models import FSQVAE
model = FSQVAE(
    img_size = 28,
    in_chans = 1,
    channels = [8, 16],
    blocks = [1, 1],
    levels = [8],
    upsample_mode = 'nearest',
    learning_rate = 1e-3,
    factor = 0.1,
    patience = 30,
    min_lr = 1e-6,
)
# train_dataloader / valid_dataloader: your torch DataLoaders
trainer = pl.Trainer(accelerator = 'gpu', devices = 4, max_epochs = 256)
trainer.fit(model, train_dataloader, valid_dataloader)
Examples
Basic usage
Here's a basic example of training a fully connected autoencoder on MNIST. The data is downloaded and loaded, and then the autoencoder is fit. Training info is logged to the output directory (training/), and a GIF of the training routine is generated for visual inspection.
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from autoencodersplz.models import LinearAE
from autoencodersplz.trainers import Trainer
train_loader = DataLoader(
    MNIST(root='data/', train=True, download=True, transform=ToTensor()),
    batch_size = 32,
    shuffle = True,
)
test_loader = DataLoader(
    MNIST(root='data/', train=False, download=True, transform=ToTensor()),
    batch_size = 32,
    shuffle = False,
)
model = LinearAE(
    img_size = 28,
    in_chans = 1,
    hidden_layers = [256, 128],
    dropout_rate = 0,
    latent_dim = 32,
    beta = 0,
)
trainer = Trainer(
    model,
    train_loader,
    test_loader,
    epochs = 32,
    learning_rate = 1e-3,
    output_dir = 'training/',
)
trainer.fit()
References
@article{hinton2006reducing,
title = {Reducing the dimensionality of data with neural networks},
author = {Geoffrey Hinton and Ruslan Salakhutdinov},
url = {https://doi.org/10.1126/science.1127647},
year = {2006},
}
@article{orhan2018skip,
title = {Skip Connections Eliminate Singularities},
author = {Emin Orhan and Xaq Pitkow},
url = {https://arxiv.org/abs/1701.09175},
year = {2018},
}
@article{he2015deep,
title = {Deep Residual Learning for Image Recognition},
author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun},
url = {https://arxiv.org/abs/1512.03385},
year = {2016},
}
@misc{oord2018neural,
title = {Neural Discrete Representation Learning},
author = {Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
url = {https://arxiv.org/abs/1711.00937},
year = {2017},
}
@misc{mentzer2023finite,
title = {Finite Scalar Quantization: VQ-VAE Made Simple},
author = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
url = {https://arxiv.org/abs/2309.15505},
year = {2023},
}
@misc{he2021masked,
title = {Masked Autoencoders Are Scalable Vision Learners},
author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
url = {https://arxiv.org/abs/2111.06377},
year = {2021},
}
@misc{tolstikhin2021mlpmixer,
title = {MLP-Mixer: An all-MLP Architecture for Vision},
author = {Ilya Tolstikhin and Neil Houlsby and Alexander Kolesnikov and Lucas Beyer and Xiaohua Zhai and Thomas Unterthiner and Jessica Yung and Andreas Steiner and Daniel Keysers and Jakob Uszkoreit and Mario Lucic and Alexey Dosovitskiy},
url = {https://arxiv.org/abs/2105.01601},
year = {2021},
}
@misc{assran2023selfsupervised,
title = {Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture},
author = {Mahmoud Assran and Quentin Duval and Ishan Misra and Piotr Bojanowski and Pascal Vincent and Michael Rabbat and Yann LeCun and Nicolas Ballas},
url = {https://arxiv.org/abs/2301.08243},
year = {2023},
}