Draw a large number of samples from a categorical distribution with large support on the GPU using PyTorch.

Fast Sampling from Categorical Distributions on the GPU using PyTorch

Currently, torch.distributions.Categorical is a bit slow if you need to draw a large number of samples from a static categorical distribution. Also, you are limited to no more than 2^24 different outcomes.

The "alias method" lets you sample very quickly from distributions with large support, and this implementation in PyTorch lets you have more than 2^24 outcomes.

I needed this for rapid generation of word embeddings in hilbert.
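
For readers who have not met the alias method, here is a minimal pure-Python sketch of the idea. It is not this package's implementation (which runs on PyTorch tensors); the names build_alias_table and sample_alias are made up for illustration. Setup costs O(n), after which every draw needs only one uniform index and one uniform float, regardless of how many outcomes there are.

```python
import random


def build_alias_table(probs):
    """Build (accept, alias) tables for probabilities that sum to 1.

    Hypothetical helper for illustration; not part of pytorch_categorical.
    """
    n = len(probs)
    scaled = [p * n for p in probs]  # rescale so the average weight is 1
    accept = [0.0] * n
    alias = [0] * n
    small = [i for i, s in enumerate(scaled) if s < 1.0]
    large = [i for i, s in enumerate(scaled) if s >= 1.0]

    # Pair each under-full slot with an over-full one that tops it up.
    while small and large:
        s = small.pop()
        l = large.pop()
        accept[s] = scaled[s]
        alias[s] = l
        scaled[l] = scaled[l] + scaled[s] - 1.0  # weight left after donating
        (small if scaled[l] < 1.0 else large).append(l)

    # Whatever remains has weight 1 up to rounding error: always accept.
    for i in small + large:
        accept[i] = 1.0
        alias[i] = i
    return accept, alias


def sample_alias(accept, alias, rng=random):
    """Draw one outcome in O(1): pick a slot, then keep it or take its alias."""
    i = rng.randrange(len(accept))
    return i if rng.random() < accept[i] else alias[i]


accept, alias = build_alias_table([0.5, 0.3, 0.2])
draws = [sample_alias(accept, alias) for _ in range(10)]
```

The expensive part is building the two tables, which is why the method pays off only when the distribution stays fixed and many samples are drawn from it.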

Install

pip install pytorch-categorical

Use

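import pytorch_categorical
import torch

# Build a random distribution over one million outcomes.
num_outcomes = int(1e6)
probs = torch.rand(num_outcomes)
probs /= probs.sum()  # normalize so the probabilities sum to 1

# Constructing the sampler is the expensive, one-time step.
sampler = pytorch_categorical.Categorical(probs)

# Drawing samples afterwards is fast.
num_samples = int(1e6)
samples = sampler.sample((num_samples,))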

The constructor also takes a dtype and a device if you want to specify them. By default

Posterity

At the time I made this, there was an open issue to incorporate a more rapid sampler based on the alias method, but nothing was released yet. Hopefully that will get into a release soon! For now, use this!

Tested. It's Correct and Fast.

I've backed this with a few simple tests, including a benchmark against torch. This implementation takes about 175X longer to construct a sampler with one million outcomes, but after this up-front cost, it draws samples (in batches of ten thousand) about 3000X faster, and the advantage grows with the number of samples eventually drawn. So the main use case is when you have to draw many samples from a distribution that does not change.
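
To make the trade-off concrete, here is a small arithmetic sketch of the break-even point. The 175X and 3000X ratios come from the benchmark above, but absolute timings are not given here, so baseline_construct and baseline_draw are hypothetical inputs you would measure yourself, and break_even_batches is an illustrative helper, not part of the package.

```python
import math


def break_even_batches(baseline_construct, baseline_draw,
                       construct_ratio=175.0, draw_speedup=3000.0):
    """Smallest number of batches after which the slower setup has paid off.

    baseline_construct: measured time to build the torch sampler (hypothetical input)
    baseline_draw: measured time for torch to draw one batch (hypothetical input)
    """
    alias_construct = baseline_construct * construct_ratio
    alias_draw = baseline_draw / draw_speedup
    extra_setup = alias_construct - baseline_construct  # added up-front cost
    saved_per_batch = baseline_draw - alias_draw        # time saved on each batch
    return math.ceil(extra_setup / saved_per_batch)


# Example with placeholder timings: setup and one batch each take one time unit.
batches_needed = break_even_batches(baseline_construct=1.0, baseline_draw=1.0)
```

If constructing the baseline sampler is cheap relative to drawing a batch, the break-even point arrives after only a few batches; if construction dominates, you need correspondingly more draws before the alias sampler wins.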

Run the correctness and benchmark tests: python test.py.

Enjoy!

Source Distribution

pytorch-categorical-0.0.3.tar.gz (4.3 kB)
