atlas learning
neurve
This is the repository accompanying the paper "Self-supervised representation learning on manifolds", to be presented at the ICLR 2021 Workshop on Geometrical and Topological Representation Learning.
Additionally, we implement a manifold version of triplet training, which will be expounded on in an upcoming preprint.
Notebooks
A notebook is provided that runs inference using a pre-trained Manifold SimCLR model (trained on CIFAR10, FashionMNIST, or MNIST).
Installation
Install via
pip install neurve
or, to install with Weights & Biases support, run:
pip install "neurve[wandb]"
You can also install from source by cloning this repository and then running, from the repo root, the command
pip install .  # or: pip install ".[wandb]"
The dependencies are:
- numpy>=1.17.4
- torch>=1.3.1
- torchvision>=0.4.2
- scipy>=1.5.3 (for parsing the cars dataset annotations)
- tqdm
- tensorboardX
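To check an existing environment against the minimum versions listed above, a small standard-library sketch follows. The helper names (parse_release, meets_minimum, check_environment) are hypothetical and not part of neurve; the version comparison deliberately ignores pre-release and local tags, so it is a simplification of full PEP 440 ordering.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Minimum versions copied from the dependency list above. tqdm and
# tensorboardX have no stated minimum, so any installed version passes.
MINIMUMS: Dict[str, Optional[str]] = {
    "numpy": "1.17.4",
    "torch": "1.3.1",
    "torchvision": "0.4.2",
    "scipy": "1.5.3",
    "tqdm": None,
    "tensorboardX": None,
}


def parse_release(version: str) -> Tuple[int, ...]:
    """Keep only the leading numeric release parts: '2.1.0+cu118' -> (2, 1, 0).

    Simplification of PEP 440: pre-release and local tags are ignored.
    """
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed: str, minimum: Optional[str]) -> bool:
    """True if `installed` is at least `minimum` (None means no constraint)."""
    if minimum is None:
        return True
    a, b = parse_release(installed), parse_release(minimum)
    # Pad with zeros so (0, 4) compares like (0, 4, 0).
    n = max(len(a), len(b))
    return a + (0,) * (n - len(a)) >= b + (0,) * (n - len(b))


def check_environment() -> Dict[str, str]:
    """Report 'ok', 'missing', or 'too old (X)' for each listed dependency."""
    report = {}
    for name, minimum in MINIMUMS.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        if meets_minimum(installed, minimum):
            report[name] = "ok"
        else:
            report[name] = f"too old ({installed})"
    return report


if __name__ == "__main__":
    for pkg, status in check_environment().items():
        print(f"{pkg}: {status}")
```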
Datasets
To get the datasets for metric learning (the datasets we use for representation learning are included in torchvision.datasets):
- CUB dataset: download the file CUB_200_2011.tgz from http://www.vision.caltech.edu/visipedia/CUB-200-2011.html and decompress it in the data folder. The folder structure should be data/CUB_200_2011/images/.
- cars196 dataset: run make data/cars.
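If you prefer to do the CUB decompression step from Python rather than with tar, a minimal sketch using the standard-library tarfile module follows. The function name extract_cub is hypothetical (not part of neurve), and it assumes the archive's top-level folder is CUB_200_2011/, which is what the expected layout data/CUB_200_2011/images/ implies.

```python
import tarfile
from pathlib import Path


def extract_cub(archive: str = "CUB_200_2011.tgz", data_dir: str = "data") -> Path:
    """Decompress the CUB archive into data_dir and return the images folder.

    Assumption: the archive's top-level directory is CUB_200_2011/.
    """
    out = Path(data_dir)
    out.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Safe-extraction filter, available in Python 3.12+ and recent backports.
            tar.extractall(out, filter="data")
        else:
            tar.extractall(out)
    images = out / "CUB_200_2011" / "images"
    if not images.is_dir():
        raise FileNotFoundError(f"expected {images} after extraction")
    return images


if __name__ == "__main__":
    print(extract_cub())
```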
Training commands
Tracking with Weights & Biases
To use Weights & Biases to log training/validation metrics and to store model checkpoints, set the environment variable NEURVE_TRACKER to wandb. Otherwise, tensorboardX is used for metric logging and model checkpoints are saved locally.
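From a shell this is simply NEURVE_TRACKER=wandb python experiments/simclr.py .... When launching runs from Python, a sketch like the following builds the environment for a subprocess. Both helper names are hypothetical, and expected_backend only restates the rule described above; it is not neurve's own selection code.

```python
import os
from typing import Dict, Mapping, Optional


def tracker_env(use_wandb: bool, base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of the environment with NEURVE_TRACKER set (or removed)."""
    env = dict(os.environ if base is None else base)
    if use_wandb:
        env["NEURVE_TRACKER"] = "wandb"
    else:
        env.pop("NEURVE_TRACKER", None)
    return env


def expected_backend(env: Mapping[str, str]) -> str:
    """Restates the documented rule: 'wandb' only when NEURVE_TRACKER == 'wandb'."""
    return "wandb" if env.get("NEURVE_TRACKER") == "wandb" else "tensorboardX"


# Usage (not executed here):
#   subprocess.run([sys.executable, "experiments/simclr.py", ...],
#                  env=tracker_env(True), check=True)
```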
Manifold SimCLR
For self-supervised training, run the command
python experiments/simclr.py \
--dataset $DATASET \
--backbone $BACKBONE \
--dim_z $DIM_Z \
--n_charts $N_CHARTS \
--n_epochs $N_EPOCHS \
--tau $TAU \
--out_path $OUT_PATH # if not using Weights & Biases for tracking
where
- $DATASET is one of "cifar", "mnist", or "fashion_mnist".
- $BACKBONE is the name of the backbone network (in the paper we used "resnet50" for CIFAR10 and "resnet18" for MNIST and FashionMNIST).
- $DIM_Z and $N_CHARTS are the dimension and number of charts, respectively, for the manifold.
- $N_EPOCHS is the number of epochs to train for (in the paper we used 1,000 for CIFAR10 and 100 for MNIST and FashionMNIST).
- $TAU is the temperature parameter for the contrastive loss function (in the paper we used 0.5 for CIFAR10 and 1.0 for MNIST and FashionMNIST).
- $OUT_PATH is the path to save model checkpoints and tensorboard output.
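The per-dataset settings above can be collected so the command is assembled consistently. The sketch below is hypothetical (PAPER_SETTINGS and simclr_command are not part of neurve); it only encodes the backbone, epoch, and temperature values stated above, and leaves $DIM_Z and $N_CHARTS to the caller because no paper values are given for them here.

```python
import shlex
from typing import List, Optional

# Values stated above for each dataset; DIM_Z and N_CHARTS are not given.
PAPER_SETTINGS = {
    "cifar": {"backbone": "resnet50", "n_epochs": 1000, "tau": 0.5},
    "mnist": {"backbone": "resnet18", "n_epochs": 100, "tau": 1.0},
    "fashion_mnist": {"backbone": "resnet18", "n_epochs": 100, "tau": 1.0},
}


def simclr_command(dataset: str, dim_z: int, n_charts: int,
                   out_path: Optional[str] = None) -> List[str]:
    """Build the argv for experiments/simclr.py using the paper's settings."""
    if dataset not in PAPER_SETTINGS:
        raise ValueError(f"dataset must be one of {sorted(PAPER_SETTINGS)}")
    s = PAPER_SETTINGS[dataset]
    argv = [
        "python", "experiments/simclr.py",
        "--dataset", dataset,
        "--backbone", s["backbone"],
        "--dim_z", str(dim_z),
        "--n_charts", str(n_charts),
        "--n_epochs", str(s["n_epochs"]),
        "--tau", str(s["tau"]),
    ]
    if out_path is not None:  # omit when tracking with Weights & Biases
        argv += ["--out_path", out_path]
    return argv


if __name__ == "__main__":
    # dim_z=2 and n_charts=8 are illustrative placeholders, not paper values.
    print(shlex.join(simclr_command("mnist", dim_z=2, n_charts=8, out_path="runs/mnist")))
```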
Manifold metric learning
For metric learning, run the command
python experiments/triplet.py \
--data_root $DATA_ROOT \
--dim_z $DIM_Z \
--n_charts $N_CHARTS \
--out_path $OUT_PATH # if not using Weights & Biases for tracking
where
- $DATA_ROOT is the path to the data (e.g. data/CUB_200_2011/images/ or data/cars/), which should be a folder of subfolders, where each subfolder has the images for one class.
- $DIM_Z and $N_CHARTS are the dimension and number of charts, respectively, for the manifold.
- $OUT_PATH is the path to save model checkpoints and tensorboard output.
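Before a long run, it can help to confirm that $DATA_ROOT has the one-subfolder-per-class layout described above. The sketch below is a hypothetical helper (class_counts is not part of neurve), and the set of accepted image extensions is an assumption; it ignores loose files in the root and rejects empty class folders.

```python
from pathlib import Path
from typing import Dict

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}  # assumption: common image types


def class_counts(data_root: str) -> Dict[str, int]:
    """Map each class subfolder of data_root to its number of image files.

    Files sitting directly in data_root are ignored; a class folder with no
    images, or a root with no class folders, raises ValueError.
    """
    root = Path(data_root)
    if not root.is_dir():
        raise FileNotFoundError(root)
    counts: Dict[str, int] = {}
    for sub in sorted(p for p in root.iterdir() if p.is_dir()):
        n = sum(1 for f in sub.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES)
        if n == 0:
            raise ValueError(f"class folder {sub.name!r} contains no images")
        counts[sub.name] = n
    if not counts:
        raise ValueError(f"{root} has no class subfolders")
    return counts


if __name__ == "__main__":
    counts = class_counts("data/CUB_200_2011/images/")
    print(f"{len(counts)} classes, {sum(counts.values())} images")
```

Sorting the subfolder names gives a stable class order, so a label mapping such as {name: i for i, name in enumerate(counts)} is reproducible across machines.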
Citation
@inproceedings{
korman2021selfsupervised,
title={Self-supervised representation learning on manifolds},
author={Eric O Korman},
booktitle={ICLR 2021 Workshop on Geometrical and Topological Representation Learning},
year={2021},
url={https://openreview.net/forum?id=EofGDIGAhvR}
}