Visualization tool for unsupervised image visualization (ICLR 2023)
Unsupervised visualization of image datasets using contrastive learning
This is the code for the paper “Unsupervised visualization of image datasets using contrastive learning” (ICLR 2023).
If you use the code, please cite our paper:
@inproceedings{boehm2023unsupervised,
title={Unsupervised visualization of image datasets using contrastive learning},
author={B{\"o}hm, Jan Niklas and Berens, Philipp and Kobak, Dmitry},
booktitle={International Conference on Learning Representations},
year={2023},
}
We show that it is possible to visualize datasets such as CIFAR-10 and CIFAR-100 in 2D with a contrastive learning objective, while preserving a lot of structure! We call our method t-SimCNE.
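The core idea can be illustrated with a short NumPy sketch. It follows our reading of the objective: an InfoNCE-style contrastive loss computed directly on 2D embeddings, where the similarity between two points is the Cauchy (Student-t) kernel 1 / (1 + ||z_i - z_j||^2). The function name and the batch layout (rows i and i + n are two augmentations of the same image) are hypothetical. This is not the implementation shipped in tsimcne.

```python
import numpy as np


def cauchy_infonce_loss(z: np.ndarray) -> float:
    """InfoNCE-style loss with a Cauchy similarity (illustrative sketch).

    Hypothetical layout: z has shape (2n, 2); rows i and i + n are the
    2D embeddings of two augmented views of the same image.
    """
    m = z.shape[0]
    n = m // 2
    # pairwise squared Euclidean distances, shape (2n, 2n)
    diff = z[:, None, :] - z[None, :, :]
    sq_dist = np.sum(diff**2, axis=-1)
    # Cauchy kernel: Student-t with one degree of freedom
    sim = 1.0 / (1.0 + sq_dist)
    # a point is never compared with itself
    np.fill_diagonal(sim, 0.0)
    # index of the positive partner for every row
    pos = np.concatenate([np.arange(n, m), np.arange(0, n)])
    pos_sim = sim[np.arange(m), pos]
    # -log(similarity to positive / sum of similarities to all others)
    return float(np.mean(-np.log(pos_sim / sim.sum(axis=1))))


# two views per image; "good" puts each pair on top of each other,
# "bad" swaps the partners so that positives are far apart
good = cauchy_infonce_loss(np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 0.0], [10.0, 0.0]]))
bad = cauchy_infonce_loss(np.array([[0.0, 0.0], [10.0, 0.0], [10.0, 0.0], [0.0, 0.0]]))
```

Minimizing such a loss pulls the two views of each image together in the plane while the heavy-tailed kernel pushes unrelated images apart, which is what lets cluster structure show up directly in 2D.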
Installation
Installation should be as easy as calling pip install . in the project root, i.e.:
git clone https://github.com/berenslab/t-simcne
cd t-simcne
pip install .
Since the project uses a pyproject.toml file, you need to make sure that your pip version is at least v22.3.1.
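To check the installed pip version from within Python, a small helper like the one below can be used. It is a sketch with hypothetical function names: it compares only the leading numeric release components, which is enough for versions such as 22.3.1 but ignores pre-release tags. Running pip --version in a terminal, and upgrading with python -m pip install --upgrade pip if needed, works just as well.

```python
import re
from importlib.metadata import PackageNotFoundError, version

MIN_PIP = (22, 3, 1)


def parse_release(ver: str) -> tuple[int, ...]:
    """Return the leading numeric components, e.g. "23.0.dev0" -> (23, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unparseable version: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def pip_is_recent_enough(minimum: tuple[int, ...] = MIN_PIP) -> bool:
    """Check whether the installed pip meets the minimum version."""
    try:
        installed = parse_release(version("pip"))
    except PackageNotFoundError:
        # pip is not installed in this environment
        return False
    # pad with zeros so that e.g. (23,) compares like (23, 0, 0)
    width = max(len(installed), len(minimum))
    installed += (0,) * (width - len(installed))
    minimum += (0,) * (width - len(minimum))
    return installed >= minimum


print(pip_is_recent_enough())
```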
Usage example
import torch
import torchvision
from matplotlib import pyplot as plt
from tsimcne.tsimcne import TSimCNE
# get the cifar dataset (make sure to adapt `data_root` to point to your folder)
data_root = "experiments/cifar/out/cifar10"
dataset_train = torchvision.datasets.CIFAR10(
root=data_root,
download=True,
train=True,
)
dataset_test = torchvision.datasets.CIFAR10(
root=data_root,
download=True,
train=False,
)
dataset_full = torch.utils.data.ConcatDataset([dataset_train, dataset_test])
# create the object (here we run t-SimCNE with fewer epochs
# than in the paper; there we used [1000, 50, 450]).
tsimcne = TSimCNE(total_epochs=[500, 50, 250])
# train on the augmented/contrastive dataloader (this takes the most time)
tsimcne.fit(dataset_full)
# map the original images to 2D
Y = tsimcne.transform(dataset_full)
# get the original labels from the dataset
labels = [lbl for img, lbl in dataset_full]
# plot the data
fig, ax = plt.subplots()
ax.scatter(*Y.T, c=labels)
fig.savefig("tsimcne.png")
[Figures: CIFAR-10 and CIFAR-100 visualizations]
Duplicates and oddities in CIFAR datasets
We found that CIFAR-10 contains more than 150 duplicates of just three separate images! Apparently this has not been discovered or discussed anywhere else; we stumbled upon it while exploring the visualizations.
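Exact duplicates like these can also be checked by hashing the raw pixel data and grouping identical hashes. The sketch below is our illustration, not the procedure behind the finding (which came from inspecting the embedding); the function name is hypothetical, and it only catches byte-identical images, not near-duplicates. For CIFAR-10, one could pass img.tobytes() for each uint8 image array in the .data attribute of torchvision's CIFAR10 dataset; the number of redundant copies is then the sum of len(group) - 1 over all groups.

```python
import hashlib
from collections import defaultdict
from typing import Iterable


def find_exact_duplicates(images: Iterable[bytes]) -> list[list[int]]:
    """Group indices of byte-identical images (illustrative sketch).

    Returns one list of indices for every image that occurs more than once.
    """
    groups: dict[str, list[int]] = defaultdict(list)
    for idx, raw in enumerate(images):
        # identical pixel buffers produce identical digests
        digest = hashlib.sha256(raw).hexdigest()
        groups[digest].append(idx)
    return [idxs for idxs in groups.values() if len(idxs) > 1]


# toy input: "a" occurs three times, "b" twice, "c" once
dups = find_exact_duplicates([b"a", b"b", b"a", b"c", b"b", b"a"])
# dups == [[0, 2, 5], [1, 4]]
```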
Furthermore, there seem to be some quite strange images in CIFAR-10.
Finally, CIFAR-100 contains a whole class of flatfish images that seem to be misplaced, but the images actually show caught flatfish together with fishermen.
Reproducibility
The figures are in figures/ and were created with the script files ending in .do in media/. To reproduce them, you need to use redo and change some variables in redo.py so that it runs on your setup. You will probably also want access to a GPU or a GPU cluster.