Sparsemax pytorch

Project description

sparsemax

A PyTorch implementation of SparseMax (https://arxiv.org/pdf/1602.02068.pdf), with gradients checked and tested.

Sparsemax is an alternative to softmax for cases where a sparse probability distribution is wanted: unlike softmax, it can assign exactly zero probability to some entries. It has been used to great effect in recent papers like ProtoAttend (https://arxiv.org/pdf/1902.06292v4.pdf).
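
To make the contrast with softmax concrete, here is a minimal, dependency-free sketch of the sorting-based sparsemax projection described in the paper linked above. The names sparsemax_reference and softmax_reference are hypothetical helpers written for this illustration. They are not part of this package's API and work on plain Python lists rather than tensors.

```python
import math


def sparsemax_reference(scores):
    """Sparsemax of a 1-D list of scores (illustrative helper, not the package API)."""
    sorted_scores = sorted(scores, reverse=True)
    cumulative = 0.0
    threshold = 0.0
    for j, value in enumerate(sorted_scores, start=1):
        cumulative += value
        # The support condition holds for a prefix of the sorted scores,
        # so the last j that satisfies it gives the support size.
        if 1.0 + j * value > cumulative:
            threshold = (cumulative - 1.0) / j
    # Shift by the threshold and clip at zero.
    return [max(value - threshold, 0.0) for value in scores]


def softmax_reference(scores):
    """Numerically stable softmax, included only for comparison."""
    largest = max(scores)
    exps = [math.exp(value - largest) for value in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [1.0, 0.5, -1.0]
print(sparsemax_reference(scores))  # [0.75, 0.25, 0.0]
print(softmax_reference(scores))    # every entry is strictly positive
```

The lowest score receives exactly zero probability under sparsemax, whereas softmax always leaves every entry positive.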

Installation

pip install -U sparsemax

Usage

Use it as a drop-in replacement for nn.Softmax().

from sparsemax import Sparsemax
import torch
import torch.nn as nn

sparsemax = Sparsemax(dim=-1)
softmax = nn.Softmax(dim=-1)

logits = torch.randn(2, 3, 5)
logits.requires_grad = True
print("\nLogits")
print(logits)

softmax_probs = softmax(logits)
print("\nSoftmax probabilities")
print(softmax_probs)

sparsemax_probs = sparsemax(logits)
print("\nSparsemax probabilities")
print(sparsemax_probs)

Advantages over existing implementations

This repo borrows heavily from: https://github.com/KrisKorrel/sparsemax-pytorch

However, there are a few key advantages:

  1. The backward pass is implemented natively as a torch.autograd.Function, giving a 30% speedup over the above repository.

  2. The package is easily pip-installable (no need to copy the code).

  3. The package works for multi-dimensional tensors, operating over any axis.

  4. The operator's forward and backward passes are tested (the backward pass is checked with torch.autograd.gradcheck).
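
As an illustration of what a hand-written backward pass for sparsemax has to compute, here is a dependency-free sketch of the gradient rule given in the sparsemax paper: on the support (the entries with non-zero probability), subtract the mean of the incoming gradient over the support, and return zero elsewhere. The function name sparsemax_backward_reference is hypothetical, and this sketch makes no claim about how the package's torch.autograd.Function is actually written.

```python
def sparsemax_backward_reference(probs, grad_output):
    """Gradient w.r.t. the scores, given sparsemax outputs and the upstream gradient.

    Illustrative helper on plain lists; not part of the package API.
    """
    support = [p > 0.0 for p in probs]
    support_size = sum(support)
    # Mean of the upstream gradient over the support only.
    mean_grad = sum(g for g, s in zip(grad_output, support) if s) / support_size
    # Entries outside the support get zero gradient.
    return [g - mean_grad if s else 0.0 for g, s in zip(grad_output, support)]


probs = [0.75, 0.25, 0.0]       # sparsemax of [1.0, 0.5, -1.0]
grad_output = [1.0, 2.0, 3.0]   # arbitrary upstream gradient
print(sparsemax_backward_reference(probs, grad_output))  # [-0.5, 0.5, 0.0]
```

The resulting gradient sums to zero, which is consistent with the outputs always summing to one.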

Check that gradients are computed correctly

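import torch
from torch.autograd import gradcheck
from sparsemax import Sparsemax

sparsemax = Sparsemax(dim=-1)

# gradcheck needs double-precision inputs that require gradients
inputs = torch.randn(6, 3, 20, dtype=torch.double, requires_grad=True)
test = gradcheck(sparsemax, inputs, eps=1e-6, atol=1e-4)
print(test)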

Credits

This package was created with Cookiecutter and the audreyr/cookiecutter-pypackage project template.

History

0.1.0 (2020-05-25)

  • First release on PyPI.

Download files

Source Distribution

sparsemax-0.1.9.tar.gz (12.1 kB)

Uploaded Source

Built Distribution

sparsemax-0.1.9-py2.py3-none-any.whl (5.5 kB)

Uploaded Python 2 Python 3

File details

Details for the file sparsemax-0.1.9.tar.gz.

File metadata

  • Download URL: sparsemax-0.1.9.tar.gz
  • Size: 12.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.4.0 requests-toolbelt/0.9.1 tqdm/4.46.0 CPython/3.8.0

File hashes

Hashes for sparsemax-0.1.9.tar.gz
Algorithm Hash digest
SHA256 85fe08d08900cbf2a0259e7925f88f59e7fe725e8981236c8b14e239b47f0f17
MD5 0d46faeeea512d64e790e354e31d7852
BLAKE2b-256 d44afe026840c0b6a7dca0741d9bdadc9c86fa132e21573679c3544bc35c0812

File details

Details for the file sparsemax-0.1.9-py2.py3-none-any.whl.

File metadata

  • Download URL: sparsemax-0.1.9-py2.py3-none-any.whl
  • Size: 5.5 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.4.0 requests-toolbelt/0.9.1 tqdm/4.46.0 CPython/3.8.0

File hashes

Hashes for sparsemax-0.1.9-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 2e7191933652dea3df223079efdd871e871440fd74d7edd7ba318c34f707e0bd
MD5 4709604f33d368c7fd72a07e82f26886
BLAKE2b-256 1cf8e56723d8279ff156dea120c67afde88be80448958bb88d5307426390794f
