Stochastic Deep Learning for PyTorch

Documentation on Read the Docs.

Storchastic is a PyTorch library for stochastic gradient estimation in deep learning [1]. Many state-of-the-art deep learning models use gradient estimation, in particular within the fields of Variational Inference and Reinforcement Learning. While PyTorch computes gradients of deterministic computation graphs automatically, it will not estimate gradients on stochastic computation graphs [2].

With Storchastic, you can define any stochastic deep learning model and let the library estimate the gradients for you. It provides a large range of gradient estimation methods that you can plug and play to find out which one works best for your problem. It also broadcasts sampled batch dimensions automatically, which makes code more readable and complex models easier to implement.

When dealing with continuous random variables and differentiable functions, the popular reparameterization method [3] is usually very effective. However, this method is not applicable when dealing with discrete random variables or non-differentiable functions. This is why Storchastic focuses on gradient estimators for discrete random variables, non-differentiable functions and sequence models.
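To make the contrast concrete, the two standard identities behind these estimators can be written side by side. This is a sketch in notation that is assumed here and does not come from the Storchastic documentation: f is the cost, θ the distribution parameters, g a differentiable sampling function, and ε noise that does not depend on θ. The cost f is assumed not to depend on θ directly.

```latex
% Reparameterization: write z = g(\epsilon, \theta) with \epsilon \sim p(\epsilon).
% Requires z to be continuous and f to be differentiable in z.
\nabla_\theta \, \mathbb{E}_{z \sim p_\theta(z)}\big[f(z)\big]
  = \mathbb{E}_{\epsilon \sim p(\epsilon)}\big[\nabla_\theta f\big(g(\epsilon, \theta)\big)\big]

% Score function (REINFORCE): only \log p_\theta(z) is differentiated,
% so z may be discrete and f may be non-differentiable.
\nabla_\theta \, \mathbb{E}_{z \sim p_\theta(z)}\big[f(z)\big]
  = \mathbb{E}_{z \sim p_\theta(z)}\big[f(z) \, \nabla_\theta \log p_\theta(z)\big]
```

The first identity pushes the gradient through the sample itself, which is why it breaks down for discrete variables. The second never differentiates f, at the price of a typically higher variance.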

Example: Discrete Variational Auto-Encoder

Installation

pip install storchastic

Requires PyTorch 1.5 (older versions will not do!) and Pyro. The code is built on Python 3.7. The master branch works with PyTorch 1.7, but the version on pip is not compatible with it. Binaries will be updated soon.

Algorithms

Feel free to create an issue if an estimator is missing here.

  • Reparameterization [1, 3]
  • Score Function (REINFORCE) with Moving Average baseline [1, 4]
  • Score Function with Batch Average Baseline [5, 6]
  • Expected value for enumerable distributions
  • (Straight through) Gumbel Softmax [7, 8]
  • LAX, RELAX [9]
  • REBAR [10]
  • REINFORCE Without Replacement [6]
  • Unordered Set Estimator [13]
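As an illustration of one entry in the list above, here is a minimal sketch of a score function (REINFORCE) estimator with a moving-average baseline for a single Bernoulli variable. It is written in plain Python with no dependencies and does not use the Storchastic API. The cost function, the parameter `theta`, the `decay` value and all function names are assumptions chosen for the example.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def cost(z):
    # Toy cost evaluated on a discrete sample z in {0, 1} (hypothetical choice).
    return (z - 0.45) ** 2


def exact_gradient(theta):
    # Enumerate both outcomes: J = p * f(1) + (1 - p) * f(0), with dp/dtheta = p * (1 - p).
    p = sigmoid(theta)
    return (cost(1) - cost(0)) * p * (1.0 - p)


def score_function_gradient(theta, n_samples, decay=0.9, seed=0):
    # Monte Carlo estimate of dJ/dtheta using the score function identity.
    rng = random.Random(seed)
    p = sigmoid(theta)
    baseline = 0.0  # moving average of previously observed costs
    total = 0.0
    for _ in range(n_samples):
        z = 1 if rng.random() < p else 0
        f = cost(z)
        score = z - p  # d/dtheta log Bernoulli(z; sigmoid(theta))
        # The baseline only uses earlier samples, so subtracting it adds no bias.
        total += (f - baseline) * score
        baseline = decay * baseline + (1.0 - decay) * f
    return total / n_samples


theta = 0.3
exact = exact_gradient(theta)
estimate = score_function_gradient(theta, n_samples=200_000)
print(f"exact={exact:.4f} estimate={estimate:.4f}")
```

Because the variable has only two outcomes, the exact gradient is available by enumeration and serves as a check on the sampled estimate. In a real model the cost would be a network output and enumeration is usually not feasible.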

In development

  • Memory Augmented Policy Optimization [11]
  • Rao-Blackwellized REINFORCE [12]

Planned

  • Measure valued derivatives [1, 14]
  • ARM [15]
  • Automatic Credit Assignment [16]
  • ...

References

Cite

To cite Storchastic, please cite this preprint:

@article{van2021storchastic,
  title={Storchastic: A Framework for General Stochastic Automatic Differentiation},
  author={van Krieken, Emile and Tomczak, Jakub M and Teije, Annette ten},
  journal={arXiv preprint arXiv:2104.00428},
  year={2021}
}
