Sparsemax pytorch

sparsemax

A PyTorch implementation of sparsemax (https://arxiv.org/pdf/1602.02068.pdf), with gradients checked and tested.

Installation

pip install -U sparsemax

Usage

Use it the same way you would use nn.Softmax():

import torch
import torch.nn as nn
from sparsemax import Sparsemax

# Both modules normalize along the last dimension.
sparsemax = Sparsemax(dim=-1)
softmax = nn.Softmax(dim=-1)

# Random scores of shape (2, 3, 5), with gradient tracking enabled.
logits = torch.randn(2, 3, 5, requires_grad=True)
print("\nLogits")
print(logits)

softmax_probs = softmax(logits)
print("\nSoftmax probabilities")
print(softmax_probs)

sparsemax_probs = sparsemax(logits)
print("\nSparsemax probabilities")
print(sparsemax_probs)
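
For intuition about what the sparsemax output represents, here is a minimal pure-Python sketch of the sort-and-threshold procedure described in the linked paper, for a single non-empty list of scores. The function name sparsemax_1d is hypothetical and is not part of this package; the package itself operates on tensors and may be implemented differently.

```python
def sparsemax_1d(z):
    """Project a non-empty list of scores onto the probability simplex.

    Illustrative only: `sparsemax_1d` is a hypothetical helper, not the
    package's API.
    """
    z_sorted = sorted(z, reverse=True)

    # Find the largest k with 1 + k * z_(k) > sum of the k largest scores.
    cumsum = 0.0
    k_support = 0
    support_sum = 0.0
    for k, z_k in enumerate(z_sorted, start=1):
        cumsum += z_k
        if 1.0 + k * z_k > cumsum:
            k_support = k
            support_sum = cumsum

    # Threshold that makes the surviving entries sum to one.
    tau = (support_sum - 1.0) / k_support

    # Scores at or below the threshold become exactly zero.
    return [max(z_i - tau, 0.0) for z_i in z]


print(sparsemax_1d([2.0, 1.0, 0.1]))   # [1.0, 0.0, 0.0]
print(sparsemax_1d([1.0, 0.5, -1.0]))  # [0.75, 0.25, 0.0]
```

Unlike softmax, which assigns every entry a strictly positive probability, the result can contain exact zeros while still summing to one.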

Check that gradients are computed correctly

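import torch
from torch.autograd import gradcheck
from sparsemax import Sparsemax

sparsemax = Sparsemax(dim=-1)

# gradcheck compares analytic and numerical gradients; use double precision.
inputs = torch.randn(6, 3, 20, dtype=torch.double, requires_grad=True)
test = gradcheck(sparsemax, (inputs,), eps=1e-6, atol=1e-4)
print(test)  # True if the gradients match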
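
To illustrate the kind of comparison such a check performs, the following pure-Python sketch compares a closed-form sparsemax gradient against central finite differences for one score vector. It assumes the Jacobian form from the linked paper, diag(s) - s sᵀ / |S|, where s marks the nonzero outputs. All function names here are hypothetical; none of them come from this package or from PyTorch.

```python
def sparsemax_1d(z):
    """Sparsemax of a non-empty list of scores (hypothetical helper)."""
    cumsum, k_support, support_sum = 0.0, 0, 0.0
    for k, z_k in enumerate(sorted(z, reverse=True), start=1):
        cumsum += z_k
        if 1.0 + k * z_k > cumsum:
            k_support, support_sum = k, cumsum
    tau = (support_sum - 1.0) / k_support
    return [max(z_i - tau, 0.0) for z_i in z]


def analytic_grad(z, v):
    """Gradient of sum(v_i * p_i) w.r.t. z, using the closed-form Jacobian."""
    support = [p_i > 0.0 for p_i in sparsemax_1d(z)]
    # Mean of v over the support; entries outside it get zero gradient.
    v_mean = sum(v_i for v_i, s in zip(v, support) if s) / sum(support)
    return [(v_i - v_mean) if s else 0.0 for v_i, s in zip(v, support)]


def numeric_grad(z, v, eps=1e-6):
    """Central finite differences of sum(v_i * p_i) w.r.t. each z_i."""
    def objective(scores):
        return sum(v_i * p_i for v_i, p_i in zip(v, sparsemax_1d(scores)))

    grads = []
    for i in range(len(z)):
        z_plus, z_minus = list(z), list(z)
        z_plus[i] += eps
        z_minus[i] -= eps
        grads.append((objective(z_plus) - objective(z_minus)) / (2 * eps))
    return grads


z = [1.0, 0.5, -1.0]
v = [0.3, -0.2, 0.7]
max_error = max(abs(a - n) for a, n in zip(analytic_grad(z, v), numeric_grad(z, v)))
print(max_error < 1e-4)  # True
```

The comparison is only meaningful away from points where the support changes, because sparsemax is piecewise linear and not differentiable at those boundaries.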

Credits

This package was created with Cookiecutter and the audreyr/cookiecutter-pypackage project template.

History

0.1.0 (2020-05-25)

  • First release on PyPI.
