Draw a large number of samples from a categorical distribution with large support on the GPU using PyTorch.
Fast Sampling from Categorical Distributions on the GPU using PyTorch
Currently, torch.distributions.Categorical is a bit slow if you need to draw a large number of samples from a static categorical distribution.
Also, you are limited to no more than 2^24 (16,777,216) different outcomes.
The "alias method" lets you sample very quickly from distributions with large support, and this implementation in PyTorch lets you have more than 2^24 outcomes.
I needed this for rapid generation of word embeddings in hilbert.
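For readers unfamiliar with it, the alias method can be sketched in plain Python as below. This is a minimal illustration of Vose's variant using only the standard library, not the code in this package (which works on PyTorch tensors); the names build_alias_table and sample are made up for this example.

```python
import random


def build_alias_table(probs):
    """Build (accept, alias) tables for the alias method in O(n) time.

    Illustrative sketch only; not this package's implementation.
    """
    n = len(probs)
    total = sum(probs)
    # Rescale so that the average bucket holds exactly 1.0 units of mass.
    scaled = [p * n / total for p in probs]
    accept = [0.0] * n        # chance of keeping the bucket's own outcome
    alias = list(range(n))    # fallback outcome for each bucket
    small = [i for i, s in enumerate(scaled) if s < 1.0]
    large = [i for i, s in enumerate(scaled) if s >= 1.0]
    while small and large:
        s = small.pop()
        l = large.pop()
        accept[s] = scaled[s]
        alias[s] = l
        # The large outcome donates mass to top bucket s up to 1.0.
        scaled[l] -= 1.0 - scaled[s]
        (small if scaled[l] < 1.0 else large).append(l)
    # Whatever remains holds 1.0 up to rounding error.
    for i in small + large:
        accept[i] = 1.0
    return accept, alias


def sample(accept, alias, rng=random):
    """Draw one outcome: pick a bucket uniformly, then keep it or take its alias."""
    i = rng.randrange(len(accept))
    return i if rng.random() < accept[i] else alias[i]


accept, alias = build_alias_table([0.5, 0.3, 0.2])
draws = [sample(accept, alias) for _ in range(10)]
```

Building the tables takes time linear in the number of outcomes, while each draw costs one random integer, one random float, and one comparison, regardless of how large the support is. That trade is why the method suits a distribution that is built once and sampled many times.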
Install
pip install pytorch-categorical
Use
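import pytorch_categorical
import torch
# Build a normalized probability vector over one million outcomes.
num_outcomes = int(1e6)
probs = torch.rand(num_outcomes)  # uniform random values in [0, 1)
probs /= probs.sum()
# Construct the sampler once, then draw many samples from it.
sampler = pytorch_categorical.Categorical(probs)
num_samples = int(1e6)
samples = sampler.sample((num_samples,))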
The constructor also takes a dtype and a device if you want to specify them. By default
Posterity
At the time I made this, there was an open issue to incorporate a more rapid sampler based on the alias method, but nothing had been released yet. Hopefully that will get into a release soon! For now, use this!
Tested. It's Correct and Fast.
I've backed this with a few simple tests, including a benchmark against torch. This implementation takes about 175X longer to construct a sampler with one million outcomes, but after this up-front cost, it draws samples (in batches of ten thousand) about 3000X faster, and the advantage grows the more samples you eventually draw. So the main use case is when you have to draw many samples from a stable distribution.
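To get a feel for when the up-front cost pays off, here is a small break-even sketch. It assumes "175X longer" means construction takes 175 times as long as the baseline and that both ratios stay constant across workloads. The function name and the baseline timings passed to it are hypothetical, so substitute measurements from your own hardware.

```python
def break_even_batches(base_build_s, base_batch_s,
                       build_ratio=175.0, speedup=3000.0):
    """Batches needed before the slower-to-build sampler is faster overall.

    base_build_s: baseline construction time in seconds (hypothetical input)
    base_batch_s: baseline time to draw one batch in seconds (hypothetical input)
    build_ratio, speedup: the two ratios quoted in the benchmark text
    """
    extra_build = (build_ratio - 1.0) * base_build_s
    saved_per_batch = base_batch_s * (1.0 - 1.0 / speedup)
    return extra_build / saved_per_batch


# Made-up timings, purely for illustration: 1 ms to build, 10 ms per batch.
batches = break_even_batches(base_build_s=0.001, base_batch_s=0.01)
```

Beyond that number of batches, every further batch widens the gap in favor of the alias-based sampler.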
Run the correctness and benchmark tests: python test.py
Enjoy!
Download files
Source Distribution
Hashes for pytorch-categorical-0.0.3.tar.gz
Algorithm | Hash digest
---|---
SHA256 | d71c6d93da62cab4d2bd57d1ea58105ba1354bf14e85a9b94ca4e9c506bf93cb
MD5 | 82176550a53d5c2f6cae3700b680fd09
BLAKE2b-256 | b0b88b82b5c66bef1fdfd8b6c5241fc175cf5e9d9ac33fec7aebeb8dda796177