
atlas learning


neurve

This is the repository accompanying the paper "Self-supervised representation learning on manifolds", to be presented at the ICLR 2021 Workshop on Geometrical and Topological Representation Learning.

Additionally, we implement a manifold version of triplet training, which will be described in an upcoming preprint.

Notebooks

MSimCLR Inference

This notebook, which can be opened in Google Colab, runs inference using a pre-trained Manifold SimCLR model (trained on CIFAR10, FashionMNIST, or MNIST).

Installation

Install via

pip install neurve

or, to install with Weights & Biases support, run:

pip install "neurve[wandb]"

You can also install from source by cloning this repository and running the following command from the repository root:

pip install .  # or: pip install ".[wandb]"

The dependencies are:

  • numpy>=1.17.4
  • torch>=1.3.1
  • torchvision>=0.4.2
  • scipy>=1.5.3 (for parsing the cars dataset annotations)
  • tqdm
  • tensorboardX
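To confirm that an existing environment meets the minimum versions listed above, a quick check along these lines can help. This is a minimal sketch, not part of neurve: the helper names are hypothetical, the version parsing is deliberately naive (it compares only the leading numeric components of each version), and packaging.version.Version is the more robust choice when the packaging library is installed. tqdm and tensorboardX are left out because no minimum version is listed for them.

```python
from importlib import metadata

# Minimum versions taken from the dependency list above.
MINIMUMS = {"numpy": "1.17.4", "torch": "1.3.1", "torchvision": "0.4.2", "scipy": "1.5.3"}


def version_tuple(version):
    """Parse the leading numeric components, e.g. '1.21.0+cu118' -> (1, 21, 0)."""
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'rc1' or 'post2'
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def check_minimums(minimums=MINIMUMS):
    """Return {package: problem} for every package that is missing or too old."""
    problems = {}
    for pkg, minimum in minimums.items():
        try:
            installed = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            problems[pkg] = "not installed"
            continue
        if version_tuple(installed) < version_tuple(minimum):
            problems[pkg] = f"{installed} < {minimum}"
    return problems
```

Calling print(check_minimums()) prints an empty dict when every listed minimum is satisfied.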

Datasets

To get the datasets for metric learning (the datasets we use for representation learning are included in torchvision.datasets):

Training commands

Tracking with Weights & Biases

To use Weights & Biases to log training/validation metrics and store model checkpoints, set the environment variable NEURVE_TRACKER to wandb. Otherwise, tensorboardX is used for metric logging and model checkpoints are saved locally.
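When launching the training scripts from Python (for example, from a sweep script), the tracker can be chosen per run by passing a modified environment to the child process. This is a minimal sketch: tracker_env is a hypothetical helper, and it assumes that leaving NEURVE_TRACKER unset gives the default tensorboardX behavior described above.

```python
import os


def tracker_env(use_wandb, base=None):
    """Return a copy of `base` (default: os.environ) with NEURVE_TRACKER configured.

    Hypothetical helper: with use_wandb=True the variable is set to "wandb";
    otherwise it is removed so the scripts fall back to tensorboardX logging
    and local checkpoints. The input mapping is never modified.
    """
    env = dict(os.environ if base is None else base)
    if use_wandb:
        env["NEURVE_TRACKER"] = "wandb"
    else:
        env.pop("NEURVE_TRACKER", None)
    return env
```

The returned dict can be passed as env= to subprocess.run along with one of the training commands; the parent process environment is left untouched.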

Manifold SimCLR

For self-supervised training, run the command

python experiments/simclr.py \
              --dataset $DATASET \
              --backbone $BACKBONE \
              --dim_z $DIM_Z \
              --n_charts $N_CHARTS \
              --n_epochs $N_EPOCHS \
              --tau $TAU \
              --out_path $OUT_PATH # if not using Weights & Biases for tracking

where

  • $DATASET is one of "cifar", "mnist", "fashion_mnist".
  • $BACKBONE is the name of the backbone network (in the paper we used "resnet50" for CIFAR10 and "resnet18" for MNIST and FashionMNIST).
  • $DIM_Z and $N_CHARTS are the dimension and number of charts, respectively, for the manifold.
  • $N_EPOCHS is the number of epochs to train for (in the paper we used 1,000 for CIFAR10 and 100 for MNIST and FashionMNIST).
  • $TAU is the temperature parameter for the contrastive loss function (in the paper we used 0.5 for CIFAR10 and 1.0 for MNIST and FashionMNIST).
  • $OUT_PATH is the path to save model checkpoints and TensorBoard output (only needed when not tracking with Weights & Biases).
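In the standard SimCLR loss (NT-Xent), the temperature divides the pairwise cosine similarities before the softmax, so a smaller $TAU sharpens the distribution over negatives. For intuition, here is that standard Euclidean loss in plain Python. It is an illustrative sketch only: neurve's manifold version of the contrastive loss is not shown here and may differ, and nt_xent is a hypothetical name, not part of the package API. Input vectors must be nonzero.

```python
import math


def _normalize(v):
    """Scale a nonzero vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def nt_xent(z1, z2, tau):
    """Standard NT-Xent loss for N positive pairs (z1[i], z2[i]).

    The two views are stacked into 2N unit vectors; for each anchor, its
    partner view is the positive and the other 2N - 2 vectors are negatives.
    """
    z = [_normalize(v) for v in z1 + z2]
    n = len(z1)
    total = 0.0
    for i in range(2 * n):
        j = (i + n) % (2 * n)  # index of the positive partner of anchor i
        sims = [sum(a * b for a, b in zip(z[i], z[k])) / tau for k in range(2 * n)]
        # Softmax denominator over every vector except the anchor itself.
        denom = sum(math.exp(s) for k, s in enumerate(sims) if k != i)
        total += math.log(denom) - sims[j]
    return total / (2 * n)
```

With the toy batch z1 = [[1, 0], [0, 1]] and tau = 0.5, passing z2 = z1 (aligned views) gives ln(2 + e^2) - 2, roughly 0.24, while swapping the rows of z2 gives roughly 2.24.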

Manifold metric learning

For metric learning, run the command

python experiments/triplet.py \
              --data_root $DATA_ROOT \
              --dim_z $DIM_Z \
              --n_charts $N_CHARTS \
              --out_path $OUT_PATH # if not using Weights & Biases for tracking

where

  • $DATA_ROOT is the path to the data (e.g. data/CUB_200_2011/images/ or data/cars/), which should be a folder of subfolders, where each subfolder has the images for one class.
  • $DIM_Z and $N_CHARTS are the dimension and number of charts, respectively, for the manifold.
  • $OUT_PATH is the path to save model checkpoints and TensorBoard output (only needed when not tracking with Weights & Biases).

Citation

@inproceedings{
  korman2021selfsupervised,
  title={Self-supervised representation learning on manifolds},
  author={Eric O Korman},
  booktitle={ICLR 2021 Workshop on Geometrical and Topological Representation Learning},
  year={2021},
  url={https://openreview.net/forum?id=EofGDIGAhvR}
}



Download files


Source Distribution

neurve-0.1.0.tar.gz (13.7 MB)

Uploaded Source

Built Distribution

neurve-0.1.0-py3-none-any.whl (18.9 kB)

Uploaded Python 3

File details

Details for the file neurve-0.1.0.tar.gz.

File metadata

  • Download URL: neurve-0.1.0.tar.gz
  • Upload date:
  • Size: 13.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.7

File hashes

Hashes for neurve-0.1.0.tar.gz
Algorithm Hash digest
SHA256 37e83d7e27ba3245718319ca6e322994d2616f40577c5abd2f6a344da02906d1
MD5 12f3e76ac4ba97b4add8b143f0dbbc88
BLAKE2b-256 58f0115f609b83f77e1bb6959efed7992124c0790dad0b2ee4129202de7ac0d8


File details

Details for the file neurve-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: neurve-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 18.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.7

File hashes

Hashes for neurve-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 4abe5409d81b6ec2a70a1ecb49f2ab51ec0e85da8684b6a4cb084ed75c994d52
MD5 ff72f4a42c7087b8cbda6a9bf8f7f6bd
BLAKE2b-256 8d09f4d4cf2f58e281a3e92b0dc534ee24490945096a7244ae254d6fc8e63a3c

