sparsemax
A PyTorch implementation of sparsemax (Martins & Astudillo, 2016: https://arxiv.org/pdf/1602.02068.pdf), with gradients checked and tested.
Free software: MIT license
Documentation: https://sparsemax.readthedocs.io.
Installation
pip install -U sparsemax
Usage
Use it as a drop-in replacement for nn.Softmax(). Nice and simple.
import torch
import torch.nn as nn

from sparsemax import Sparsemax

sparsemax = Sparsemax(dim=-1)
softmax = nn.Softmax(dim=-1)

logits = torch.randn(2, 3, 5)
logits.requires_grad = True
print("\nLogits")
print(logits)

softmax_probs = softmax(logits)
print("\nSoftmax probabilities")
print(softmax_probs)

sparsemax_probs = sparsemax(logits)
print("\nSparsemax probabilities")
print(sparsemax_probs)
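Unlike softmax, sparsemax can assign exactly zero probability to low-scoring logits, because it is the Euclidean projection of the logits onto the probability simplex. As an illustration only (not the package's implementation, which also provides the backward pass), here is a minimal plain-Python sketch of the forward projection; the name `sparsemax_ref` is hypothetical:

```python
def sparsemax_ref(z):
    """Sketch of the sparsemax forward pass: project logits z onto the simplex."""
    z_sorted = sorted(z, reverse=True)      # logits in descending order
    cumsum = 0.0
    k_z, tau_sum = 0, 0.0
    for k, z_k in enumerate(z_sorted, start=1):
        cumsum += z_k
        if 1 + k * z_k > cumsum:            # support condition from the paper
            k_z, tau_sum = k, cumsum        # k_z = number of nonzero outputs
    tau = (tau_sum - 1.0) / k_z             # threshold subtracted from every logit
    return [max(z_i - tau, 0.0) for z_i in z]

print(sparsemax_ref([2.0, 1.0, 0.1]))  # [1.0, 0.0, 0.0]
```

The output sums to 1, and here the two smaller logits receive exactly zero mass, whereas softmax would give them small positive probabilities.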
Check that gradients are computed correctly:

import torch
from torch.autograd import gradcheck

from sparsemax import Sparsemax

sparsemax = Sparsemax(dim=-1)
# gradcheck needs double-precision inputs that require gradients
input = torch.randn(6, 3, 20, dtype=torch.double, requires_grad=True)
test = gradcheck(sparsemax, input, eps=1e-6, atol=1e-4)
print(test)  # True if analytical and numerical gradients match
Credits
This package was created with Cookiecutter and the audreyr/cookiecutter-pypackage project template.
History
0.1.0 (2020-05-25)
First release on PyPI.
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
sparsemax-0.1.6.tar.gz (11.3 kB)
Built Distribution
Hashes for sparsemax-0.1.6-py2.py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | deec73f605657754b16a32495f16d7eeec099f9afda0f8dc1ed3b91d269515eb
MD5 | c8f66008e0a91d6c8fc6111ab769c98c
BLAKE2b-256 | 0d11448d20bda9b2abd088febf41a6c2a59673fbd3e83e3dd62c4ebca2e90037