Draw a large number of samples from a categorical distribution with large support on the GPU using PyTorch.

Fast Sampling from Categorical Distributions on the GPU using PyTorch

Currently, torch.distributions.Categorical is a bit slow if you need to draw a large number of samples from a static categorical distribution. Also, you are limited to no more than 2^24 different outcomes.

If you need categorical distributions with really large support, and/or you need to quickly draw millions of samples, then this is for you.

At the time I made this, there was an open issue to incorporate a more rapid sampler based on the alias method (which I use here). Hopefully that will get into a release soon! For now, use this!
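For readers unfamiliar with the alias method, here is a minimal pure-Python sketch of the idea (Vose's variant). It is not this package's code: the names `build_alias_table` and `draw` are hypothetical and chosen only for illustration. Building the tables takes one pass over the outcomes; each draw afterwards costs one random slot choice and one comparison.

```python
import random


def build_alias_table(weights):
    """Build alias tables for a list of non-negative weights.

    Returns (accept, alias): slot i yields outcome i with probability
    accept[i], and outcome alias[i] otherwise.
    """
    n = len(weights)
    total = sum(weights)
    # Rescale so that the average slot mass is exactly 1.
    scaled = [w * n / total for w in weights]
    accept = [1.0] * n
    alias = list(range(n))
    small = [i for i, s in enumerate(scaled) if s < 1.0]
    large = [i for i, s in enumerate(scaled) if s >= 1.0]
    while small and large:
        s = small.pop()
        l = large.pop()
        # Slot s keeps its own mass and borrows the rest from outcome l.
        accept[s] = scaled[s]
        alias[s] = l
        scaled[l] -= 1.0 - scaled[s]
        (small if scaled[l] < 1.0 else large).append(l)
    # Any slots left over (rounding error) keep accept == 1.0.
    return accept, alias


def draw(accept, alias, rng=random):
    """Draw one outcome in constant time from prebuilt alias tables."""
    i = rng.randrange(len(accept))
    return i if rng.random() < accept[i] else alias[i]
```

A typical use would be to call `build_alias_table` once for a fixed distribution and then call `draw` as many times as needed.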

I've backed this with a few simple tests, including a benchmark against torch. This implementation takes about 100X longer to construct a sampler, but after this up-front cost, it yields samples about 3500X faster. So the main use case is when you have to draw many samples from a distribution that does not change.
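To give a feel for why the per-sample cost can be so low once the tables exist, here is a hedged sketch of a batched alias-method draw in PyTorch. It is an illustration, not this package's implementation: `draw_many` is a hypothetical name, and the two tables are worked out by hand for the probabilities [0.5, 0.25, 0.25]. Each sample needs only a random slot, a random uniform, and one comparison, all of which run as whole-tensor operations.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hand-computed alias tables for probabilities [0.5, 0.25, 0.25]:
# slot i yields outcome i with probability accept[i], else alias[i].
accept = torch.tensor([1.0, 0.75, 0.75], device=device)
alias = torch.tensor([0, 0, 0], device=device)


def draw_many(accept, alias, num_samples):
    """Draw num_samples outcomes at once from prebuilt alias tables."""
    n = accept.shape[0]
    # Pick a slot uniformly at random for every sample.
    slots = torch.randint(n, (num_samples,), device=accept.device)
    # Keep the slot's own outcome or switch to its alias.
    u = torch.rand(num_samples, device=accept.device)
    return torch.where(u < accept[slots], slots, alias[slots])


samples = draw_many(accept, alias, 1_000_000)
# Empirical frequencies, which should be close to [0.5, 0.25, 0.25].
freqs = torch.bincount(samples, minlength=3).float() / samples.numel()
```

The tables are reused across calls, which is why the approach only pays off when the distribution stays fixed.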

Enjoy!

Download files

Source Distribution

pytorch-categorical-0.0.1.tar.gz (3.7 kB view hashes)

Uploaded Source
