
Differentiable Sorting, Ranking, and Top-k.


Differentiable Top-k Classification Learning


Official implementation of our ICML 2022 paper "Differentiable Top-k Classification Learning".

The difftopk library provides different differentiable sorting and ranking methods as well as a wrapper for using them in a TopKCrossEntropyLoss. difftopk builds on PyTorch.

Paper @ arXiv, Video @ YouTube.

💻 Installation

difftopk can be installed via pip from PyPI with

pip install difftopk

Sparse Computation

To evaluate the differentiable top-k operators sparsely, the torch-sparse package must be installed. This can be done, e.g., via

pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.13.0+cpu.html

For more information on how to install torch-sparse, see the torch-sparse installation instructions.
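A quick way to check whether the optional sparse dependencies are importable is sketched below. The helper name `has_module` is hypothetical and not part of difftopk; `torch_sparse` and `torch_scatter` are the import names of the pip packages above.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module `name` can be found in this environment."""
    return importlib.util.find_spec(name) is not None


# Report the status of the optional dependencies without importing them.
for module in ("torch_sparse", "torch_scatter"):
    status = "available" if has_module(module) else "missing"
    print(f"{module}: {status}")
```

If either module is missing, the methods not marked as sparse in the documentation below (e.g. `bitonic__non_sparse` or `odd_even`) should presumably still be usable.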

Example for Full Installation from Scratch and with all Dependencies


Depending on your system (e.g., your CUDA version), the following commands will have to be adjusted.

virtualenv -p python3 .env_topk
. .env_topk/bin/activate
pip install boto3 numpy requests scikit-learn tqdm 
pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
pip install diffsort

# optional for torch-sparse
FORCE_CUDA=1 pip install --no-cache-dir torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.13.0+cu116.html

pip install .

👩‍💻 Documentation

The difftopk library provides differentiable sorting and ranking methods as well as a wrapper for using them in a TopKCrossEntropyLoss. The differentiable sorting and ranking methods included are:

  • Variants of Differentiable Sorting Networks
    • bitonic Bitonic Differentiable Sorting Networks (sparse)
    • bitonic__non_sparse Bitonic Differentiable Sorting Networks
    • splitter_selection Differentiable Splitter Selection Networks (sparse)
    • odd_even Odd-Even Differentiable Sorting Networks
  • neuralsort NeuralSort
  • softsort SoftSort

Furthermore, this library also includes the smooth top-k loss from Lapin et al. (SmoothTopKLoss and SmoothHardTopKLoss).

TopKCrossEntropyLoss

At the center of the library is difftopk.TopKCrossEntropyLoss, which may be used as a drop-in replacement for torch.nn.CrossEntropyLoss. Its signature is as follows:

import difftopk

loss_fn = difftopk.TopKCrossEntropyLoss(
    diffsort_method='odd_even',       # the sorting / ranking method as discussed above
    inverse_temperature=2,            # the inverse temperature / steepness
    p_k=[.5, 0., 0., 0., .5],         # the distribution P_K
    n=1000,                           # number of classes
    m=16,                             # the number m of scores to be sorted (can be smaller than n to make it efficient)
    distribution='cauchy',            # the distribution used for differentiable sorting networks
    art_lambda=None,                  # the lambda for the ART used if `distribution='logistic_phi'`
    device='cpu',                     # the device to compute the loss on
    top1_mode='sm'                    # makes training more stable and is the default value
)

It can be used as loss_fn(outputs, labels).
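To illustrate what the distribution P_K (the `p_k` argument) does, here is a conceptual sketch in plain Python. It is not the library's implementation: it assumes, following one reading of the paper, that the loss is a P_K-weighted sum of negative log-probabilities that the true class lies within the top k. The function name `topk_cross_entropy` and the input `rank_probs` are hypothetical.

```python
import math


def topk_cross_entropy(rank_probs, p_k):
    """Conceptual top-k cross-entropy for a single sample.

    rank_probs[j]: relaxed probability that the true class is ranked (j+1)-th.
    p_k[j]:        weight P_K(k = j+1); the weights must sum to 1.
    """
    if len(rank_probs) != len(p_k):
        raise ValueError("rank_probs and p_k must have the same length")
    if abs(sum(p_k) - 1.0) > 1e-6:
        raise ValueError("p_k must sum to 1")

    loss = 0.0
    cumulative = 0.0
    for prob, weight in zip(rank_probs, p_k):
        cumulative += prob  # probability that the true class is within the top (j+1)
        if weight > 0.0:
            loss -= weight * math.log(max(cumulative, 1e-12))
    return loss


rank_probs = [0.6, 0.2, 0.1, 0.05, 0.02]
print(topk_cross_entropy(rank_probs, [1.0, 0.0, 0.0, 0.0, 0.0]))  # pure top-1
print(topk_cross_entropy(rank_probs, [0.5, 0.0, 0.0, 0.0, 0.5]))  # top-1 and top-5 mixed
```

Under this reading, `p_k=[.5, 0., 0., 0., .5]` puts half of the weight on top-1 and half on top-5, so a sample whose true class is ranked second is penalized less than under pure top-1.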

DiffTopkNet

difftopk.DiffTopkNet follows the signature of diffsort.DiffSortNet from the diffsort package. However, instead of returning the full differentiable permutation matrices of size n×n, it returns differentiable top-k attribution matrices of size n×k. More specifically, given an input of shape b×n, the module returns a tuple of None and a Tensor of shape b×n×k. (It returns a tuple to maintain the signature of DiffSortNet.)

import torch
import difftopk

sorter = difftopk.DiffTopkNet(
    sorting_network_type='bitonic',
    size=16,                          # Number of inputs
    k=5,                              # k
    sparse=True,                      # whether to use a sparse implementation
    device='cpu',                     # the device
    steepness=10.0,                   # the inverse temperature
    art_lambda=0.25,                  # the lambda for the ART used if `distribution='logistic_phi'`
    distribution='cauchy'             # the distribution used for the differentiable relaxation
)

# Usage example for difftopk on a random input
x = torch.randperm(16).unsqueeze(0).float() * 100.
print(x @ sorter(x)[1][0])
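The product `x @ sorter(x)[1][0]` above multiplies a 1×n input by an n×k attribution matrix and therefore yields k values. The plain-Python sketch below shows the hard (non-relaxed) limit of such a matrix, where the product recovers exactly the k largest inputs. The helpers `hard_topk_attribution` and `vec_mat` are hypothetical, and ordering the columns from largest to smallest is an assumption; the library's column order may differ.

```python
def hard_topk_attribution(scores, k):
    """Return an n x k 0/1 matrix whose column j selects the (j+1)-th largest score."""
    n = len(scores)
    # Indices of the k largest scores, largest first.
    top = sorted(range(n), key=lambda i: scores[i], reverse=True)[:k]
    matrix = [[0.0] * k for _ in range(n)]
    for col, row in enumerate(top):
        matrix[row][col] = 1.0
    return matrix


def vec_mat(x, matrix):
    """Multiply a length-n vector by an n x k matrix."""
    cols = len(matrix[0])
    return [sum(x[i] * matrix[i][j] for i in range(len(x))) for j in range(cols)]


scores = [300.0, 100.0, 500.0, 0.0, 200.0, 400.0]
A = hard_topk_attribution(scores, k=3)
print(vec_mat(scores, A))
```

With the differentiable relaxation, the entries of the matrix are soft rather than 0/1, so the product is only approximately equal to the top-k values; a higher steepness should bring it closer to the hard result.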

NeuralSort / SoftSort

sorter = difftopk.NeuralSort(
    tau=2.,     # A temperature parameter
)
sorter = difftopk.SoftSort(tau=2.)
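To give an intuition for the temperature `tau`, here is a minimal standalone sketch of the SoftSort relaxation in plain Python, assuming the common formulation in which row i is a softmax over the negative absolute differences between the i-th largest score and all scores. The function `soft_sort` is hypothetical and independent of `difftopk.SoftSort`, whose internals may differ.

```python
import math


def soft_sort(scores, tau=1.0):
    """Row i is softmax(-|sorted_desc[i] - scores| / tau), a relaxed permutation row."""
    sorted_desc = sorted(scores, reverse=True)
    matrix = []
    for target in sorted_desc:
        logits = [-abs(target - s) / tau for s in scores]
        peak = max(logits)  # subtract the maximum for numerical stability
        exps = [math.exp(logit - peak) for logit in logits]
        total = sum(exps)
        matrix.append([e / total for e in exps])
    return matrix


P = soft_sort([2.0, 5.0, 1.0], tau=0.1)
for row in P:
    print([round(p, 3) for p in row])
```

A small `tau` makes each row close to one-hot (a hard permutation), whereas a large `tau` spreads the mass across entries and gives smoother gradients.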

🧪 Experiments

🧫 ImageNet Fine-Tuning

We provide pre-computed embeddings for the ImageNet data set. ⚠️ These embedding files are very large (>10 GB each). Feel free to also use the embeddings for other fine-tuning experiments.

# Resnext101 WSL ImageNet-1K (~11GB each)
wget https://nyc3.digitaloceanspaces.com/publicdata1/ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x48d_320.p
wget https://nyc3.digitaloceanspaces.com/publicdata1/ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x32d_320.p
wget https://nyc3.digitaloceanspaces.com/publicdata1/ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x16d_320.p
wget https://nyc3.digitaloceanspaces.com/publicdata1/ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x8d_320.p

# Resnext101 WSL ImageNet-21K-P (~50GB)
wget https://publicdata1.nyc3.digitaloceanspaces.com/ImageNet21K-P_embeddings_labels_train_test_IGAM_Resnext101_32x48d_224_float16.p

# Noisy Student EfficientNet-L2 ImageNet-1K (~14GB)
wget https://publicdata1.nyc3.digitaloceanspaces.com/ImageNet_embeddings_labels_train_test_tf_efficientnet_l2_ns_timm_transform_800_float16.p

The following are the hyperparameter combinations for reproducing the tables in the paper. The DiffSortNet entries in Table 5 can be reproduced using

python experiments/train_imagenet.py  -d ./ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x48d_320.p  --nloglr 4.5 \
    --p_k .2 .2 .2 .2 .2  --m 16  --method bitonic  --distribution logistic_phi  --inverse_temperature 1.  --art_lambda .5 
python experiments/train_imagenet.py  -d ./ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x48d_320.p  --nloglr 4.5 \
    --p_k .2 .2 .2 .2 .2  --m 16  --method splitter_selection  --distribution logistic_phi  --inverse_temperature 1.  --art_lambda .5

python experiments/train_imagenet.py  -d ./ImageNet_embeddings_labels_train_test_tf_efficientnet_l2_ns_timm_transform_800_float16.p  --nloglr 4.5 \
    --p_k .25 .0 .0 .0 .75  --m 16  --method bitonic  --distribution logistic  --inverse_temperature .5 
python experiments/train_imagenet.py  -d ./ImageNet_embeddings_labels_train_test_tf_efficientnet_l2_ns_timm_transform_800_float16.p  --nloglr 4.5 \
    --p_k .25 .0 .0 .0 .75  --m 16  --method splitter_selection  --distribution logistic  --inverse_temperature .5 

For the remaining methods and tables, the hyperparameters are as follows:

# Tables 2+3 (1K):
python experiments/train_imagenet.py -d ./ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x48d_320.p  --m 16  --nloglr 4.5

# Tables 2+3 (21K):
python experiments/train_imagenet.py -d ./ImageNet21K-P_embeddings_labels_train_test_IGAM_Resnext101_32x48d_224_float16.p  --m 50  --nloglr 4.  --n_epochs 40 

# combined with one of each of the following

--method softmax_cross_entropy
--method bitonic             --distribution logistic_phi  --inverse_temperature 1.  --art_lambda .5
--method splitter_selection  --distribution logistic_phi  --inverse_temperature 1.  --art_lambda .5
--method neuralsort          --inverse_temperature 0.5
--method softsort            --inverse_temperature 0.5
--method smooth_hard_topk    --inverse_temperature 1. 

--p_k 1.  0. 0. 0. 0.
--p_k 0.  0. 0. 0. 1.
--p_k .5  0. 0. 0. .5
--p_k .25 0. 0. 0. .75
--p_k .1  0. 0. 0. .9
--p_k .2  .2 .2 .2 .2
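The combinations above can be enumerated with a short script. This is a sketch that assumes the full cross product of method lines and `--p_k` lines is intended for the 1K base command; the paper's tables may use only a subset. The helper `build_commands` is hypothetical, and all flags are copied from the listing above.

```python
import itertools
import shlex

# Base command for Tables 2+3 (1K), copied from the listing above.
BASE = [
    "python", "experiments/train_imagenet.py",
    "-d", "./ImageNet_embeddings_labels_train_test_IGAM_Resnext101_32x48d_320.p",
    "--m", "16", "--nloglr", "4.5",
]

METHODS = [
    "--method softmax_cross_entropy",
    "--method bitonic --distribution logistic_phi --inverse_temperature 1. --art_lambda .5",
    "--method splitter_selection --distribution logistic_phi --inverse_temperature 1. --art_lambda .5",
    "--method neuralsort --inverse_temperature 0.5",
    "--method softsort --inverse_temperature 0.5",
    "--method smooth_hard_topk --inverse_temperature 1.",
]

P_KS = [
    "1. 0. 0. 0. 0.",
    "0. 0. 0. 0. 1.",
    ".5 0. 0. 0. .5",
    ".25 0. 0. 0. .75",
    ".1 0. 0. 0. .9",
    ".2 .2 .2 .2 .2",
]


def build_commands():
    """Return one shell command per (method, p_k) pair."""
    return [
        shlex.join(BASE + method.split() + ["--p_k"] + p_k.split())
        for method, p_k in itertools.product(METHODS, P_KS)
    ]


for command in build_commands():
    print(command)
```

The script only prints the commands; they can be piped into a shell or a job scheduler as needed.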

🎆 CIFAR-100 Training from Scratch

In addition to ImageNet fine-tuning, we can also train a ResNet18 on CIFAR-100 from scratch.

# Tables 1+4:
python experiments/train_cifar100.py

--method softmax_cross_entropy
--method bitonic             --distribution logistic_phi  --inverse_temperature .5  --art_lambda .5
--method splitter_selection  --distribution logistic_phi  --inverse_temperature .5  --art_lambda .5
--method neuralsort          --inverse_temperature 0.0625
--method softsort            --inverse_temperature 0.0625
--method smooth_hard_topk    --inverse_temperature 1.

--p_k 1.  0. 0. 0. 0.
--p_k 0.  0. 0. 0. 1.
--p_k .5  0. 0. 0. .5
--p_k .25 0. 0. 0. .75
--p_k .1  0. 0. 0. .9
--p_k .2  .2 .2 .2 .2
  
# Examples:
python experiments/train_cifar100.py --method softmax_cross_entropy --p_k 1. 0. 0. 0. 0.
python experiments/train_cifar100.py --method splitter_selection --distribution logistic_phi --inverse_temperature .5 --art_lambda .5 --p_k .2 .2 .2 .2 .2

📖 Citing

@inproceedings{petersen2022difftopk,
  title={{Differentiable Top-k Classification Learning}},
  author={Petersen, Felix and Kuehne, Hilde and Borgelt, Christian and Deussen, Oliver},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2022}
}

📜 License

difftopk is released under the MIT license. See LICENSE for details.
