
Sparsemax pytorch


sparsemax


A PyTorch implementation of Sparsemax (https://arxiv.org/pdf/1602.02068.pdf), with checked and tested gradients.

Sparsemax is an alternative to softmax for when you want sparse probability distributions, in which some probabilities are exactly zero. It has been used to great effect in recent papers such as ProtoAttend (https://arxiv.org/pdf/1902.06292v4.pdf).
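For reference, this is the definition from the Sparsemax paper (Martins and Astudillo, 2016): sparsemax is the Euclidean projection of the logits z onto the probability simplex, and it has a closed form based on sorting. The symbols below (K, τ, k(z), and z_(j) for the j-th largest entry of z) are a restatement for this README, not identifiers from the package.

```latex
\begin{aligned}
\operatorname{sparsemax}(z) &= \arg\min_{p \in \Delta^{K-1}} \lVert p - z \rVert_2^2,
  \qquad \Delta^{K-1} = \Bigl\{ p \in \mathbb{R}^K : p_i \ge 0,\ \textstyle\sum_{i=1}^{K} p_i = 1 \Bigr\} \\
\operatorname{sparsemax}_i(z) &= \max\bigl(z_i - \tau(z),\ 0\bigr),
  \qquad \tau(z) = \frac{\bigl(\sum_{j=1}^{k(z)} z_{(j)}\bigr) - 1}{k(z)} \\
k(z) &= \max\Bigl\{ k \in \{1,\dots,K\} : 1 + k\, z_{(k)} > \textstyle\sum_{j=1}^{k} z_{(j)} \Bigr\},
  \qquad z_{(1)} \ge z_{(2)} \ge \dots \ge z_{(K)}
\end{aligned}
```

For example, with z = (1, 2, 3) only k = 1 satisfies the condition, so τ = 2 and sparsemax returns (0, 0, 1), while softmax returns roughly (0.090, 0.245, 0.665).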

Installation

pip install -U sparsemax

Usage

Use it as if it were nn.Softmax(). Nice and simple.

import torch
import torch.nn as nn
from sparsemax import Sparsemax

# Both modules normalize over the last dimension.
sparsemax = Sparsemax(dim=-1)
softmax = nn.Softmax(dim=-1)

# A batch of random logits; gradients can flow back through either module.
logits = torch.randn(2, 3, 5, requires_grad=True)
print("\nLogits")
print(logits)

softmax_probs = softmax(logits)
print("\nSoftmax probabilities")
print(softmax_probs)

# Unlike softmax, some of these entries are typically exactly zero.
sparsemax_probs = sparsemax(logits)
print("\nSparsemax probabilities")
print(sparsemax_probs)
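To make the difference between the two outputs concrete, here is a minimal pure-PyTorch sketch of the sparsemax forward pass, using the sort-based closed form from the Sparsemax paper. The helper name `sparsemax_reference` is hypothetical: it is an illustration for building intuition, not this package's implementation, and it covers the forward pass only.

```python
import torch


def sparsemax_reference(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Hypothetical sort-based sparsemax forward pass (not the package's code)."""
    # Move the target axis to the end so the algorithm works on the last dimension.
    z = z.transpose(dim, -1)
    # Sort each slice in descending order: z_(1) >= z_(2) >= ... >= z_(K).
    z_sorted, _ = torch.sort(z, dim=-1, descending=True)
    k = torch.arange(1, z.size(-1) + 1, dtype=z.dtype, device=z.device)
    cumsum = z_sorted.cumsum(dim=-1)
    # k(z): number of largest entries that stay nonzero (always >= 1).
    k_z = (1 + k * z_sorted > cumsum).sum(dim=-1, keepdim=True)
    # Threshold tau(z) = (sum of the k(z) largest entries - 1) / k(z).
    tau = (cumsum.gather(-1, k_z - 1) - 1) / k_z.to(z.dtype)
    # Shift by tau, clip at zero, and move the axis back where it was.
    return torch.clamp(z - tau, min=0).transpose(dim, -1)


logits = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 0.8, 0.1]])
print(torch.softmax(logits, dim=-1))     # every entry strictly positive
print(sparsemax_reference(logits))       # row 0 -> [0, 0, 1], row 1 -> approx. [0.6, 0.4, 0]
```

Swapping the target axis to the last position and back is one simple way to support an arbitrary `dim`.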

Advantages over existing implementations

This repo borrows heavily from: https://github.com/KrisKorrel/sparsemax-pytorch

However, there are a few key advantages:

  1. Backward pass equations implemented natively as a torch.autograd.Function, resulting in a 30% speedup compared to the repository above.
  2. The package is easily pip-installable (no need to copy the code).
  3. The package works for multi-dimensional tensors, operating over any axis.
  4. The operator's forward and backward passes are tested (the backward pass is checked with torch.autograd.gradcheck).
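As an illustration of what a native backward pass as a torch.autograd.Function can look like, here is a minimal sketch, assuming the input is normalized over its last dimension only. It is not this package's source; the class name `SparsemaxLastDim` is hypothetical. It uses the sparsemax Jacobian from the paper: with s the indicator of the support S = {i : p_i > 0}, the Jacobian is diag(s) - s sᵀ / |S|, so the gradient with respect to the input is s ⊙ (v - v̂), where v is the incoming gradient and v̂ is its mean over S.

```python
import torch
from torch.autograd import Function, gradcheck


class SparsemaxLastDim(Function):
    """Hypothetical sketch: sparsemax over the last dimension with a closed-form backward."""

    @staticmethod
    def forward(ctx, z):
        # Sort-based forward pass (Martins & Astudillo, 2016).
        z_sorted, _ = torch.sort(z, dim=-1, descending=True)
        k = torch.arange(1, z.size(-1) + 1, dtype=z.dtype, device=z.device)
        cumsum = z_sorted.cumsum(dim=-1)
        k_z = (1 + k * z_sorted > cumsum).sum(dim=-1, keepdim=True)
        tau = (cumsum.gather(-1, k_z - 1) - 1) / k_z.to(z.dtype)
        p = torch.clamp(z - tau, min=0)
        # The output alone is enough to recover the support in backward.
        ctx.save_for_backward(p)
        return p

    @staticmethod
    def backward(ctx, grad_output):
        (p,) = ctx.saved_tensors
        # Indicator of the support S = {i : p_i > 0}; never empty.
        support = (p > 0).to(grad_output.dtype)
        # Mean of the incoming gradient over the support.
        v_hat = (grad_output * support).sum(dim=-1, keepdim=True) / support.sum(dim=-1, keepdim=True)
        # Vector-Jacobian product s * (v - v_hat); zero outside the support.
        return support * (grad_output - v_hat)


torch.manual_seed(0)
z = torch.randn(4, 10, dtype=torch.double, requires_grad=True)
print(gradcheck(SparsemaxLastDim.apply, (z,), eps=1e-6, atol=1e-4))
```

Because the backward only needs the support and a per-row mean, it avoids building the full K × K Jacobian, which is one plausible source of the speedup over differentiating a generic implementation.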

Check that gradients are computed correctly

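import torch
from torch.autograd import gradcheck
from sparsemax import Sparsemax

sparsemax = Sparsemax(dim=-1)

# gradcheck compares analytical and numerical gradients; run it in double precision.
inputs = (torch.randn(6, 3, 20, dtype=torch.double, requires_grad=True),)
test = gradcheck(sparsemax, inputs, eps=1e-6, atol=1e-4)
print(test)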

Credits

This package was created with Cookiecutter and the audreyr/cookiecutter-pypackage project template.

History

0.1.0 (2020-05-25)

  • First release on PyPI.
