
Draw a large number of samples from a categorical distribution with large support on the GPU using PyTorch.

Project description

Fast Sampling from Categorical Distributions on the GPU using PyTorch

Currently, torch.distributions.Categorical is slow if you need to draw a large number of samples from a fixed categorical distribution. You are also limited to at most 2^24 distinct outcomes.

The "alias method" lets you sample very quickly from distributions with large support: after a one-time O(n) setup, each draw costs O(1). This PyTorch implementation also lets you have more than 2^24 outcomes.
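To illustrate the idea, here is a minimal pure-Python sketch of Vose's variant of the alias method. It is not this package's code (which works on PyTorch tensors on the GPU), and the names build_alias_table and alias_sample are hypothetical, chosen for illustration. Setup spreads the probability mass over n equal buckets, each holding at most two outcomes; a draw then needs only one uniform bucket index and one biased coin flip, which is why the method vectorizes well.

```python
import random
from typing import List, Optional, Sequence, Tuple


def build_alias_table(probs: Sequence[float]) -> Tuple[List[float], List[int]]:
    """O(n) setup (Vose's alias method); probs need not be normalized."""
    n = len(probs)
    total = float(sum(probs))
    # Scale so that the average bucket holds exactly 1.0 units of mass.
    scaled = [p * n / total for p in probs]
    keep = [0.0] * n  # chance that bucket i returns outcome i itself
    alias = [0] * n   # outcome that bucket i returns otherwise
    under = [i for i, q in enumerate(scaled) if q < 1.0]
    over = [i for i, q in enumerate(scaled) if q >= 1.0]
    while under and over:
        small, big = under.pop(), over.pop()
        keep[small] = scaled[small]
        alias[small] = big
        # `big` donates (1 - scaled[small]) of its mass to fill bucket `small`.
        scaled[big] = (scaled[big] + scaled[small]) - 1.0
        (under if scaled[big] < 1.0 else over).append(big)
    # Anything left over is a full bucket (up to floating-point rounding).
    for i in under + over:
        keep[i] = 1.0
        alias[i] = i
    return keep, alias


def alias_sample(keep: List[float], alias: List[int], k: int,
                 rng: Optional[random.Random] = None) -> List[int]:
    """Draw k outcomes, each in O(1): pick a bucket, then flip a biased coin."""
    rng = rng if rng is not None else random.Random()
    n = len(keep)
    draws = []
    for _ in range(k):
        i = rng.randrange(n)
        draws.append(i if rng.random() < keep[i] else alias[i])
    return draws
```

For example, building a table from [0.1, 0.2, 0.3, 0.4] and drawing 100,000 samples should return outcome 3 roughly 40% of the time. A GPU version would replace the per-draw loop with batched tensor operations over all k draws at once.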

I needed this for rapid generation of word embeddings in hilbert.

Install

pip install pytorch-categorical

Use

import pytorch_categorical
import torch

num_outcomes = int(1e6)
probs = torch.rand(num_outcomes)  # uniform random weights, one per outcome
probs /= probs.sum()

sampler = pytorch_categorical.Categorical(probs)

num_samples = int(1e6)
samples = sampler.sample((num_samples,))

The constructor also takes a dtype and a device if you want to specify them. By default …

Posterity

At the time I made this, there was an open issue to incorporate a faster sampler based on the alias method, but nothing had been released yet. Hopefully that will get into a release soon! For now, use this!

Tested. It's Correct and Fast.

I've backed this with a few simple tests, including a benchmark against torch. For a distribution with one million outcomes, this implementation takes about 175X longer than torch to construct a sampler, but after that up-front cost it draws samples (in batches of ten thousand) about 3000X faster, and the advantage grows the more samples you eventually draw. So the main use case is drawing many samples from a fixed distribution.
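A rough break-even estimate follows from those two ratios, assuming they hold and that costs scale linearly with the number of batches. Let c be torch's construction time and s torch's time per batch of ten thousand draws (symbols introduced here for illustration; their absolute values are not reported). After k batches the alias sampler is ahead when:

```latex
\begin{aligned}
175c + k\,\frac{s}{3000} &< c + k\,s \\
\iff\quad 174c &< k\,s\left(1 - \tfrac{1}{3000}\right) \\
\iff\quad k &> \frac{174 \cdot 3000}{2999}\,\frac{c}{s} \approx 174\,\frac{c}{s}
\end{aligned}
```

So it pays off after roughly 174·c/s batches; the actual ratio c/s depends on your hardware and distribution.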

Run the correctness and benchmark tests: python test.py.

Enjoy!



Download files


Source Distribution

pytorch-categorical-0.0.3.tar.gz (4.3 kB)


File details

Details for the file pytorch-categorical-0.0.3.tar.gz.

File metadata

  • Download URL: pytorch-categorical-0.0.3.tar.gz
  • Upload date:
  • Size: 4.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.21.0 setuptools/40.9.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.7.1

File hashes

Hashes for pytorch-categorical-0.0.3.tar.gz
Algorithm Hash digest
SHA256 d71c6d93da62cab4d2bd57d1ea58105ba1354bf14e85a9b94ca4e9c506bf93cb
MD5 82176550a53d5c2f6cae3700b680fd09
BLAKE2b-256 b0b88b82b5c66bef1fdfd8b6c5241fc175cf5e9d9ac33fec7aebeb8dda796177

