Sparsemax pytorch
sparsemax
A PyTorch implementation of sparsemax (https://arxiv.org/pdf/1602.02068.pdf), with gradients checked and tested.
Sparsemax is an alternative to softmax for when one wants sparse probability distributions, that is, distributions in which some entries are exactly zero. It has been used in recent papers such as ProtoAttend (https://arxiv.org/pdf/1902.06292v4.pdf).
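In the paper's formulation, sparsemax is the Euclidean projection of the logits z onto the probability simplex: sparsemax(z)_i = max(z_i − τ(z), 0), where the threshold τ(z) is chosen so that the outputs sum to 1. Every entry below the threshold gets exactly zero probability, which is what makes the output sparse. The dependency-free sketch below illustrates the paper's sort-based algorithm for a single vector. It is not this package's implementation, and the helper names `sparsemax_1d` and `softmax_1d` are hypothetical.

```python
import math
from typing import List


def softmax_1d(z: List[float]) -> List[float]:
    """Softmax for one vector, shifted by the max for numerical stability."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def sparsemax_1d(z: List[float]) -> List[float]:
    """Sparsemax for one vector via the sort-based algorithm (hypothetical helper)."""
    z_sorted = sorted(z, reverse=True)
    running_sum = 0.0
    k = 0              # size of the support found so far
    support_sum = 0.0  # sum of the k largest entries
    for j, z_j in enumerate(z_sorted, start=1):
        running_sum += z_j
        # The condition 1 + j * z_(j) > (sum of the top j entries) holds for a
        # prefix of j; the largest such j is the size of the support.
        if 1 + j * z_j > running_sum:
            k = j
            support_sum = running_sum
    tau = (support_sum - 1) / k  # threshold that makes the outputs sum to 1
    return [max(v - tau, 0.0) for v in z]


if __name__ == "__main__":
    z = [1.5, 1.0, -1.0]
    print(sparsemax_1d(z))  # [0.75, 0.25, 0.0]
    print(softmax_1d(z))    # every entry strictly positive
```

On the same logits, softmax assigns some probability to every entry, while sparsemax zeroes out the smallest one.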
Installation
pip install -U sparsemax
Usage
Use it as you would use nn.Softmax(). Nice and simple.
import torch
import torch.nn as nn
from sparsemax import Sparsemax

# Both modules normalize over the last dimension
sparsemax = Sparsemax(dim=-1)
softmax = nn.Softmax(dim=-1)

logits = torch.randn(2, 3, 5, requires_grad=True)
print("\nLogits")
print(logits)

softmax_probs = softmax(logits)
print("\nSoftmax probabilities")
print(softmax_probs)

# Unlike softmax, sparsemax can assign exactly zero probability to some entries
sparsemax_probs = sparsemax(logits)
print("\nSparsemax probabilities")
print(sparsemax_probs)
Advantages over existing implementations
This repo borrows heavily from: https://github.com/KrisKorrel/sparsemax-pytorch
However, there are a few key advantages:
- The backward-pass equations are implemented natively as a torch.autograd.Function, resulting in a 30% speedup compared to the above repository.
- The package is easily pip-installable (no need to copy the code).
- The package works for multi-dimensional tensors, operating over any axis.
- The operator's forward and backward passes are tested (the backward pass is checked with torch.autograd.gradcheck).
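For reference, the paper derives the sparsemax backward pass in closed form. With S the support (the entries where the output is nonzero) and s its 0/1 indicator vector, the Jacobian of sparsemax is diag(s) − s sᵀ / |S|. Because this Jacobian is symmetric, the gradient with respect to the input is s_i (v_i − v̂), where v is the upstream gradient and v̂ is the mean of v over S. The sketch below implements that formula for a single vector. It illustrates the math only and is not the package's torch.autograd.Function; the name `sparsemax_backward_1d` is hypothetical.

```python
from typing import List


def sparsemax_backward_1d(probs: List[float], grad_output: List[float]) -> List[float]:
    """Gradient of a loss w.r.t. the sparsemax input, given the forward output
    `probs` and the upstream gradient `grad_output` (hypothetical helper)."""
    in_support = [p > 0.0 for p in probs]  # S(z): entries with nonzero probability
    support_size = sum(in_support)         # |S| >= 1, since probs sums to 1
    # Mean of the upstream gradient over the support
    v_hat = sum(g for g, s in zip(grad_output, in_support) if s) / support_size
    # Inside the support: subtract the mean; outside: the gradient is exactly zero
    return [g - v_hat if s else 0.0 for g, s in zip(grad_output, in_support)]


if __name__ == "__main__":
    probs = [0.75, 0.25, 0.0]  # sparsemax([1.5, 1.0, -1.0])
    print(sparsemax_backward_1d(probs, [1.0, 0.0, 5.0]))  # [0.5, -0.5, 0.0]
```

Note that this backward pass needs only the support mask and the mean of the upstream gradient; the sort used in the forward pass does not appear in it.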
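To check that gradients are computed correctly:
import torch
from torch.autograd import gradcheck
from sparsemax import Sparsemax

sparsemax = Sparsemax(dim=-1)
# gradcheck compares analytical gradients with finite differences, so use float64 inputs
inputs = (torch.randn(6, 3, 20, dtype=torch.double, requires_grad=True),)
test = gradcheck(sparsemax, inputs, eps=1e-6, atol=1e-4)
print(test)  # True if the check passes (by default gradcheck raises on failure)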
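gradcheck compares the gradients computed by the backward pass with gradients estimated numerically by finite differences. The dependency-free sketch below applies the same idea to a single vector, probing the scalar f(z) = Σ_j v_j · sparsemax(z)_j with central differences and comparing the result against the closed-form gradient from the paper. All helper names are hypothetical, and the check assumes no input lies exactly on the threshold τ, where sparsemax is not differentiable.

```python
from typing import List


def sparsemax_1d(z: List[float]) -> List[float]:
    """Sort-based sparsemax forward pass for one vector."""
    z_sorted = sorted(z, reverse=True)
    running_sum, k, support_sum = 0.0, 0, 0.0
    for j, z_j in enumerate(z_sorted, start=1):
        running_sum += z_j
        if 1 + j * z_j > running_sum:  # j-th largest entry is in the support
            k, support_sum = j, running_sum
    tau = (support_sum - 1) / k
    return [max(v - tau, 0.0) for v in z]


def sparsemax_backward_1d(probs: List[float], grad_output: List[float]) -> List[float]:
    """Closed-form gradient: v_i minus the mean of v over the support, else 0."""
    in_support = [p > 0.0 for p in probs]
    v_hat = sum(g for g, s in zip(grad_output, in_support) if s) / sum(in_support)
    return [g - v_hat if s else 0.0 for g, s in zip(grad_output, in_support)]


def numerical_grad(z: List[float], v: List[float], eps: float = 1e-6) -> List[float]:
    """Central finite differences of f(z) = sum_j v_j * sparsemax(z)_j."""
    def f(x: List[float]) -> float:
        return sum(a * b for a, b in zip(v, sparsemax_1d(x)))

    grads = []
    for i in range(len(z)):
        z_plus, z_minus = list(z), list(z)
        z_plus[i] += eps
        z_minus[i] -= eps
        grads.append((f(z_plus) - f(z_minus)) / (2 * eps))
    return grads


def check_grad(z: List[float], v: List[float], atol: float = 1e-4) -> bool:
    """True if the analytical and numerical gradients agree within `atol`."""
    analytical = sparsemax_backward_1d(sparsemax_1d(z), v)
    numerical = numerical_grad(z, v)
    return all(abs(a - n) <= atol for a, n in zip(analytical, numerical))


if __name__ == "__main__":
    print(check_grad([0.2, 0.1, 0.0, -3.0], [1.0, 2.0, 3.0, 4.0]))  # True
```

Because sparsemax is piecewise linear, central differences are essentially exact away from the kinks, which is why a tight tolerance works here.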
Credits
This package was created with Cookiecutter and the audreyr/cookiecutter-pypackage project template.
History
0.1.0 (2020-05-25)
First release on PyPI.
Download files
File details
Details for the file sparsemax-0.1.9.tar.gz.
File metadata
- Download URL: sparsemax-0.1.9.tar.gz
- Upload date:
- Size: 12.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.4.0 requests-toolbelt/0.9.1 tqdm/4.46.0 CPython/3.8.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 85fe08d08900cbf2a0259e7925f88f59e7fe725e8981236c8b14e239b47f0f17
MD5 | 0d46faeeea512d64e790e354e31d7852
BLAKE2b-256 | d44afe026840c0b6a7dca0741d9bdadc9c86fa132e21573679c3544bc35c0812
File details
Details for the file sparsemax-0.1.9-py2.py3-none-any.whl.
File metadata
- Download URL: sparsemax-0.1.9-py2.py3-none-any.whl
- Upload date:
- Size: 5.5 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.4.0 requests-toolbelt/0.9.1 tqdm/4.46.0 CPython/3.8.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2e7191933652dea3df223079efdd871e871440fd74d7edd7ba318c34f707e0bd
MD5 | 4709604f33d368c7fd72a07e82f26886
BLAKE2b-256 | 1cf8e56723d8279ff156dea120c67afde88be80448958bb88d5307426390794f
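To confirm that a downloaded file matches the SHA256 digest listed above, one option is Python's standard hashlib module. The sketch below assumes the source archive was saved as sparsemax-0.1.9.tar.gz in the current directory; the helper names `sha256_of` and `matches_expected` are hypothetical.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_expected(path: Path, expected_hex: str) -> bool:
    """Compare a file's SHA256 digest with a published one (case-insensitive)."""
    return sha256_of(path) == expected_hex.strip().lower()


if __name__ == "__main__":
    # Assumed local path of the downloaded source archive
    archive = Path("sparsemax-0.1.9.tar.gz")
    expected = "85fe08d08900cbf2a0259e7925f88f59e7fe725e8981236c8b14e239b47f0f17"
    if archive.exists():
        print("SHA256 OK" if matches_expected(archive, expected) else "SHA256 MISMATCH")
    else:
        print(f"{archive} not found")
```

Alternatively, pip can enforce digests at install time through its hash-checking mode (`--require-hashes`).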