Visualization tool for unsupervised image visualization (ICLR 2023)

Unsupervised visualization of image datasets using contrastive learning

This is the code for the paper “Unsupervised visualization of image datasets using contrastive learning” (ICLR 2023).

If you use the code, please cite our paper:

@inproceedings{boehm2023unsupervised,
  title={Unsupervised visualization of image datasets using contrastive learning},
  author={B{\"o}hm, Jan Niklas and Berens, Philipp and Kobak, Dmitry},
  booktitle={International Conference on Learning Representations},
  year={2023},
}

We show that a contrastive learning objective can embed image datasets such as CIFAR-10 and CIFAR-100 directly in 2D while preserving much of their structure. We call our method t-SimCNE.

Figure: t-SimCNE architecture.

Installation

Install the package from PyPI:

pip install tsimcne

To install from source instead, clone the repository and install it with pip:

git clone https://github.com/berenslab/t-simcne
cd t-simcne
pip install .

Since the project uses a pyproject.toml file, make sure that your pip version is at least 22.3.1.
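If you are unsure which pip version is installed, a small check like the following can help. This is only a sketch: the helper names `parse_version` and `pip_is_recent_enough` are made up for illustration and are not part of tsimcne, and pre-release suffixes such as `rc1` are simply ignored.

```python
import re
from importlib import metadata

MIN_PIP = (22, 3, 1)  # minimum pip version mentioned in the installation notes


def parse_version(version: str) -> tuple:
    """Turn a version string such as '22.3.1' into a tuple of three integers."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep only the leading digits
        parts.append(int(match.group()) if match else 0)
    # pad short versions such as '23.0' with zeros
    return tuple(parts + [0] * (3 - len(parts)))


def pip_is_recent_enough(version: str, minimum: tuple = MIN_PIP) -> bool:
    """Return True if `version` is at least `minimum`."""
    return parse_version(version) >= minimum


try:
    installed = metadata.version("pip")
except metadata.PackageNotFoundError:
    print("pip was not found in this environment")
else:
    if pip_is_recent_enough(installed):
        print(f"pip {installed} is recent enough")
    else:
        print(f"pip {installed} is too old; try: python -m pip install --upgrade pip")
```

Running `python -m pip --version` in a shell gives the same information without any code.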

Usage example

The documentation is available on Read the Docs. Below is a simple usage example.

import torch
import torchvision
from matplotlib import pyplot as plt
from tsimcne.tsimcne import TSimCNE

# get the cifar dataset (make sure to adapt `data_root` to point to your folder)
data_root = "experiments/cifar/out/cifar10"
dataset_train = torchvision.datasets.CIFAR10(
    root=data_root,
    download=True,
    train=True,
)
dataset_test = torchvision.datasets.CIFAR10(
    root=data_root,
    download=True,
    train=False,
)
dataset_full = torch.utils.data.ConcatDataset([dataset_train, dataset_test])

# create the object (here we run t-SimCNE with fewer epochs
# than in the paper; there we used [1000, 50, 450]).
tsimcne = TSimCNE(total_epochs=[500, 50, 250])

# train on the augmented/contrastive dataloader (this takes the most time)
tsimcne.fit(dataset_full)

# map the original images to 2D
Y = tsimcne.transform(dataset_full)

# get the original labels from the dataset
labels = [lbl for img, lbl in dataset_full]

# plot the data
fig, ax = plt.subplots()
ax.scatter(*Y.T, c=labels)
fig.savefig("tsimcne.png")
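The list comprehension in the example above collects the labels by iterating over the whole dataset, which also loads every image. As a possible shortcut, the labels can be read from the datasets directly, assuming each CIFAR dataset object exposes a `targets` list and the concatenated dataset exposes its parts as `datasets` (to our knowledge this holds for torchvision and PyTorch, but check your installed versions). The sketch below uses small stand-in classes so that it runs without downloading anything; `StandInDataset`, `StandInConcat` and `collect_labels` are hypothetical names used only for illustration.

```python
class StandInDataset:
    """Hypothetical stand-in for torchvision.datasets.CIFAR10 (has `targets`)."""

    def __init__(self, targets):
        self.targets = list(targets)


class StandInConcat:
    """Hypothetical stand-in for torch.utils.data.ConcatDataset (has `datasets`)."""

    def __init__(self, datasets):
        self.datasets = list(datasets)


def collect_labels(concat_dataset):
    """Gather labels in dataset order without touching any image."""
    return [label for ds in concat_dataset.datasets for label in ds.targets]


# made-up labels; with the real data these would be the CIFAR-10 class indices
dataset_train = StandInDataset([3, 8, 8, 0])
dataset_test = StandInDataset([6, 1])
dataset_full = StandInConcat([dataset_train, dataset_test])

labels = collect_labels(dataset_full)
print(labels)  # [3, 8, 8, 0, 6, 1]
```

With the real objects from the usage example, the same call would be `collect_labels(dataset_full)`, and the resulting order (train first, then test) should match the order of the rows in `Y`.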

CIFAR-10

Figure: annotated plot of CIFAR-10.

CIFAR-100

Figure: label density for CIFAR-100.

Reproducibility

To reproduce the results of the paper, please see the iclr2023 branch in this repository.
