atlas learning
neurve
This repository accompanies the paper "Self-supervised representation learning on manifolds", to be presented at the ICLR 2021 Workshop on Geometrical and Topological Representation Learning.
We also implement a manifold version of triplet training, which will be described in more detail in an upcoming preprint.
Notebooks
This notebook runs inference using a pre-trained Manifold SimCLR model (trained on CIFAR10, FashionMNIST, or MNIST).
Installation
Install via
pip install neurve
or, to install with Weights & Biases support, run:
pip install "neurve[wandb]"
You can also install from source by cloning this repository and then running, from the repo root, the command
pip install .  # or: pip install ".[wandb]"
The dependencies are:
- numpy>=1.17.4
- torch>=1.3.1
- torchvision>=0.4.2
- scipy>=1.5.3 (for parsing the cars dataset annotations)
- tqdm
- tensorboardX
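To check an existing environment against the minimum versions above, a small sketch using the standard-library importlib.metadata (Python 3.8+) could look like the following. The helper names are hypothetical, and the version parsing is deliberately simplified: it drops local tags such as +cu118 and ignores pre-release suffixes. The packaging library's Version class is the robust alternative.

```python
import re
from importlib import metadata

# Minimum versions as listed in this README (tqdm and tensorboardX have no stated minimum).
MINIMUMS = {
    "numpy": "1.17.4",
    "torch": "1.3.1",
    "torchvision": "0.4.2",
    "scipy": "1.5.3",
}


def parse_version(version: str) -> tuple:
    """Turn '2.1.0+cu118' or '1.26.4' into a tuple of ints, e.g. (2, 1, 0).

    Simplified on purpose: drops local tags after '+' and keeps only the
    leading digits of each dot-separated part, so '1.0rc1' becomes (1, 0).
    """
    parts = []
    for piece in version.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if the installed version is at least the minimum (tuple comparison)."""
    return parse_version(installed) >= parse_version(minimum)


def check_installed(minimums: dict = MINIMUMS) -> dict:
    """Map each package to True/False, or None if it is not installed."""
    results = {}
    for name, minimum in minimums.items():
        try:
            results[name] = meets_minimum(metadata.version(name), minimum)
        except metadata.PackageNotFoundError:
            results[name] = None
    return results


if __name__ == "__main__":
    for name, ok in check_installed().items():
        status = "missing" if ok is None else ("ok" if ok else "too old")
        print(f"{name}: {status}")
```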
Datasets
To get the datasets for metric learning (the datasets we use for representation learning are included in torchvision.datasets):
- CUB dataset: download the file CUB_200_2011.tgz from http://www.vision.caltech.edu/visipedia/CUB-200-2011.html and decompress it in the data folder. The folder structure should be data/CUB_200_2011/images/.
- cars196 dataset: run make data/cars.
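The metric learning script expects its data root to be a folder of subfolders, one per class. A minimal sketch for confirming that a decompressed dataset has that layout before training; the function name and the set of image extensions are assumptions, not part of neurve.

```python
from pathlib import Path

# File extensions counted as images (an assumption; adjust if your copy differs).
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def summarize_class_folders(data_root: str) -> dict:
    """Return {class_folder_name: number_of_images} for each subfolder of data_root.

    Raises FileNotFoundError if data_root is not an existing directory.
    """
    root = Path(data_root)
    if not root.is_dir():
        raise FileNotFoundError(f"{root} is not a directory")
    counts = {}
    # Each immediate subdirectory is treated as one class.
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts
```

Pointed at data/CUB_200_2011/images/, this should report 200 entries (one per bird species in CUB-200-2011), and at data/cars/ it should report 196; a class with a count of zero usually means an incomplete download or extraction.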
Training commands
Tracking with Weights & Biases
To use Weights & Biases to log training/validation metrics and for storing model checkpoints, set the environment variable NEURVE_TRACKER to wandb. Otherwise tensorboardX will be used for metric logging and model checkpoints will be saved locally.
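The selection rule above can be summarized in a few lines. This is a hypothetical helper written for illustration, not neurve's own code; the exact, case-sensitive match on "wandb" and the error raised when no output path is given are assumptions.

```python
import os
from typing import Mapping, Optional, Tuple


def resolve_tracking(
    out_path: Optional[str] = None, environ: Optional[Mapping[str, str]] = None
) -> Tuple[str, Optional[str]]:
    """Return (tracker, local_checkpoint_dir) following the rule described above.

    Hypothetical helper: with NEURVE_TRACKER=wandb, metrics and checkpoints go to
    Weights & Biases, so no local path is needed; otherwise tensorboardX logs and
    checkpoints are written under out_path.
    """
    env = os.environ if environ is None else environ
    # Assumption: an exact, case-sensitive match on "wandb".
    if env.get("NEURVE_TRACKER") == "wandb":
        return "wandb", None
    if out_path is None:
        # Assumption: a local output path is mandatory without Weights & Biases.
        raise ValueError("--out_path is required when not tracking with Weights & Biases")
    return "tensorboardX", out_path
```

In a shell, the variable can be set for a single run by prefixing a training command with NEURVE_TRACKER=wandb.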
Manifold SimCLR
For self-supervised training, run the command
python experiments/simclr.py \
--dataset $DATASET \
--backbone $BACKBONE \
--dim_z $DIM_Z \
--n_charts $N_CHARTS \
--n_epochs $N_EPOCHS \
--tau $TAU \
--out_path $OUT_PATH # if not using Weights & Biases for tracking
where
- $DATASET is one of "cifar", "mnist", "fashion_mnist".
- $BACKBONE is the name of the backbone network (in the paper we used "resnet50" for CIFAR10 and "resnet18" for MNIST and FashionMNIST).
- $DIM_Z and $N_CHARTS are the dimension and number of charts, respectively, for the manifold.
- $N_EPOCHS is the number of epochs to train for (in the paper we used 1,000 for CIFAR10 and 100 for MNIST and FashionMNIST).
- $TAU is the temperature parameter for the contrastive loss function (in the paper we used 0.5 for CIFAR10 and 1.0 for MNIST and FashionMNIST).
- $OUT_PATH is the path to save model checkpoints and tensorboard output.
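To show what the temperature $TAU does, here is a dependency-free sketch of the standard SimCLR NT-Xent loss. It is an illustration only: the Manifold SimCLR loss trained by experiments/simclr.py is not reproduced here and may combine terms differently. The function names are hypothetical, and embeddings are assumed to be non-zero vectors.

```python
import math
from typing import List, Sequence


def _cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity of two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def nt_xent(z: List[Sequence[float]], tau: float) -> float:
    """Standard NT-Xent loss averaged over 2N embeddings.

    z[i] and z[i + N] are the two augmented views of the same image. Each view's
    positive is its partner; the other 2N - 2 views act as negatives.
    """
    n2 = len(z)
    if n2 == 0 or n2 % 2 != 0:
        raise ValueError("expected a non-empty, even number of embeddings")
    n = n2 // 2
    total = 0.0
    for i in range(n2):
        j = (i + n) % n2  # index of the positive partner
        # Temperature-scaled similarities to every other view (k != i).
        logits = {k: _cosine(z[i], z[k]) / tau for k in range(n2) if k != i}
        # -log softmax at the positive = logsumexp(logits) - logits[j],
        # computed with max-subtraction for numerical stability.
        m = max(logits.values())
        log_denominator = m + math.log(sum(math.exp(s - m) for s in logits.values()))
        total += log_denominator - logits[j]
    return total / n2
```

Because similarities are divided by tau before the softmax, a smaller tau sharpens the distribution over candidates; the paper's 0.5 for CIFAR10 versus 1.0 for MNIST and FashionMNIST is a choice of this knob.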
Manifold metric learning
For metric learning, run the command
python experiments/triplet.py \
--data_root $DATA_ROOT \
--dim_z $DIM_Z \
--n_charts $N_CHARTS \
--out_path $OUT_PATH # if not using Weights & Biases for tracking
where
- $DATA_ROOT is the path to the data (e.g. data/CUB_200_2011/images/ or data/cars/), which should be a folder of subfolders, where each subfolder has the images for one class.
- $DIM_Z and $N_CHARTS are the dimension and number of charts, respectively, for the manifold.
- $OUT_PATH is the path to save model checkpoints and tensorboard output.
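experiments/triplet.py trains the manifold version of triplet training, whose details are left to the upcoming preprint, so its exact loss is not given here. For orientation, this is a sketch of the ordinary Euclidean triplet margin loss that such a method presumably adapts; the function names and the margin default of 0.2 are placeholders, not values from this repository.

```python
import math
from typing import Sequence


def euclidean(u: Sequence[float], v: Sequence[float]) -> float:
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_margin_loss(
    anchor: Sequence[float],
    positive: Sequence[float],
    negative: Sequence[float],
    margin: float = 0.2,  # placeholder value, not taken from neurve
) -> float:
    """max(0, d(anchor, positive) - d(anchor, negative) + margin).

    Zero once the negative is at least `margin` farther from the anchor
    than the positive (same-class) example is.
    """
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)
```

PyTorch provides a batched version of this loss as torch.nn.TripletMarginLoss, whose default margin is 1.0.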
Citation
@inproceedings{
korman2021selfsupervised,
title={Self-supervised representation learning on manifolds},
author={Eric O Korman},
booktitle={ICLR 2021 Workshop on Geometrical and Topological Representation Learning},
year={2021},
url={https://openreview.net/forum?id=EofGDIGAhvR}
}