Draw a large number of samples from a categorical distribution with large support on the GPU using Pytorch.
Project description
Fast Sampling from Categorical Distributions on the GPU using PyTorch
Currently, torch.distributions.Categorical
is slow if you
need to draw a large number of samples from a static categorical distribution.
It also limits you to at most 2^24 distinct outcomes.
If you need a categorical distribution with very large support, and/or you need to draw millions of samples quickly, then this package is for you.
At the time I made this, there was an open PyTorch issue to incorporate a faster sampler based on the alias method (which I use here). Hopefully that will land in a release soon! Until then, use this!
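To illustrate the idea behind the alias method mentioned above, here is a minimal pure-Python sketch (for illustration only; the function names are my own, and the actual package operates on GPU tensors via PyTorch):

```python
import random

def build_alias_table(probs):
    """Precompute the alias and probability tables in O(n) (Walker/Vose construction)."""
    n = len(probs)
    scaled = [p * n for p in probs]          # rescale so the average mass per bin is 1
    prob = [0.0] * n
    alias = [0] * n
    small = [i for i, p in enumerate(scaled) if p < 1.0]
    large = [i for i, p in enumerate(scaled) if p >= 1.0]
    while small and large:
        s = small.pop()                      # an under-full bin
        l = large.pop()                      # an over-full bin
        prob[s] = scaled[s]
        alias[s] = l                         # top up bin s with mass from outcome l
        scaled[l] -= 1.0 - scaled[s]
        (small if scaled[l] < 1.0 else large).append(l)
    for leftover in small + large:           # remaining bins are exactly full
        prob[leftover] = 1.0
    return prob, alias

def draw(prob, alias):
    """Draw one sample in O(1): pick a bin uniformly, then accept it or take its alias."""
    i = random.randrange(len(prob))
    return i if random.random() < prob[i] else alias[i]
```

The construction is the expensive part; after it, every sample costs O(1), which is exactly the trade-off described below (slow sampler construction, very fast sampling).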
I've backed this with a few simple tests, including a benchmark against torch. This implementation takes about 100X longer to construct a sampler, but once that up-front cost is paid, it draws samples about 3500X faster. So the main use case is drawing many samples from a fixed distribution.
Enjoy!
Project details
Release history
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Hashes for pytorch-categorical-0.0.1.tar.gz
Algorithm | Hash digest
---|---
SHA256 | df34d1efddae20e26869a713e6b4c5f013f7a48e3af31050351206098f7c57c1
MD5 | 8e4851e01c259a5f130c877f43f622f0
BLAKE2b-256 | 4d7e71e0ba8e8375b1380bbaddbd00ca3f040098f62402e2fc04f331c4582ab3