A Python library for discrete variables relaxation
Just Relax It
Discrete Variables Relaxation
"Just Relax It" is a Python library that streamlines the optimization of discrete probability distributions in neural networks by offering a suite of relaxation techniques compatible with PyTorch.
📬 Assets
- Technical Meeting 1 - Presentation
- Technical Meeting 2 - Jupyter Notebook
- Technical Meeting 3 - Jupyter Notebook
- Blog Post
- Documentation
- Tests
💡 Motivation
Many mathematical problems require the ability to sample discrete random variables. The difficulty is that deep learning relies on gradient-based optimization, and sampling a truly discrete random variable is not differentiable with respect to the parameters of its distribution. Relaxation methods work around this by replacing the discrete sample with a continuous surrogate. One of them, the Concrete distribution, also known as Gumbel-Softmax (the same distribution was proposed in parallel by two research groups), is implemented in several DL packages. In this project we implement alternatives to it.
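To make the idea of a relaxation concrete, here is a minimal plain-Python sketch of Gumbel-Softmax, independent of this library's API (all function names are illustrative). `gumbel_max` draws an exact categorical sample by taking the `argmax` of Gumbel-perturbed logits, which has no useful gradient; `gumbel_softmax` replaces that `argmax` with a softmax at temperature `temperature`, so the sample becomes a continuous point on the probability simplex.

```python
import math
import random


def sample_gumbel(rng: random.Random) -> float:
    """Draw one standard Gumbel(0, 1) sample by inverse transform: -log(-log(U))."""
    u = max(rng.random(), 1e-12)  # rng.random() is in [0, 1); avoid log(0)
    return -math.log(-math.log(u))


def gumbel_max(logits: list[float], rng: random.Random) -> int:
    """Exact categorical sample (Gumbel-max trick); the argmax is not differentiable."""
    perturbed = [x + sample_gumbel(rng) for x in logits]
    return max(range(len(perturbed)), key=perturbed.__getitem__)


def gumbel_softmax(logits: list[float], temperature: float, rng: random.Random) -> list[float]:
    """Relaxed sample: softmax of the Gumbel-perturbed logits divided by the temperature.

    The result lies on the probability simplex; lower temperatures push it toward one-hot.
    """
    perturbed = [(x + sample_gumbel(rng)) / temperature for x in logits]
    shift = max(perturbed)  # subtract the max for numerical stability
    exps = [math.exp(p - shift) for p in perturbed]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    logits = [math.log(0.2), math.log(0.5), math.log(0.3)]
    for tau in (5.0, 1.0, 0.1):
        # Reusing the seed reuses the Gumbel noise, so the soft and hard samples share an argmax.
        soft = gumbel_softmax(logits, tau, random.Random(0))
        hard = gumbel_max(logits, random.Random(0))
        print(f"tau={tau}: soft={[round(p, 3) for p in soft]}, hard index={hard}")
```

Because dividing by a positive temperature does not change which perturbed logit is largest, the relaxed sample always agrees with the exact sample on its largest coordinate; the temperature only controls how close to one-hot it is.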
🗃 Algorithms
- Relaxed Bernoulli, also see 📝 paper
- Correlated relaxed Bernoulli, also see 📝 paper
- Gumbel-Softmax top-k, also see 📝 paper
- Straight-Through Bernoulli, also see 📝 paper
- Invertible Gaussian (with KL divergence implemented), also see 📝 paper
- Hard Concrete, also see 📝 paper
- REINFORCE, also see 📺 slides
- Logit-Normal and Laplace-form approximation of Dirichlet, also see ℹ️ wiki and 💻 stackexchange
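Two of the methods above, Relaxed Bernoulli (the binary Concrete distribution) and Hard Concrete, share one building block: logistic noise added to a logit and pushed through a tempered sigmoid. The plain-Python sketch below illustrates only the sampling step, without autograd; it is not relaxit's implementation, the function names are hypothetical, and the Hard Concrete defaults (stretch interval (-0.1, 1.1), `beta = 2/3`) are assumed common choices rather than values taken from this library.

```python
import math
import random


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logistic_noise(rng: random.Random) -> float:
    """Standard Logistic(0, 1) noise: log(u) - log(1 - u) with u ~ Uniform(0, 1)."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    return math.log(u) - math.log(1.0 - u)


def relaxed_bernoulli(logit: float, temperature: float, rng: random.Random) -> float:
    """Binary Concrete sample in [0, 1]; thresholding it at 0.5 gives a Bernoulli(sigmoid(logit)) draw."""
    return sigmoid((logit + logistic_noise(rng)) / temperature)


def hard_concrete(log_alpha: float, rng: random.Random,
                  beta: float = 2 / 3, gamma: float = -0.1, zeta: float = 1.1) -> float:
    """Hard Concrete sample: stretch a binary Concrete sample to (gamma, zeta), then clip to [0, 1].

    Clipping puts non-zero probability mass on exactly 0 and exactly 1.
    """
    s = sigmoid((log_alpha + logistic_noise(rng)) / beta)
    stretched = s * (zeta - gamma) + gamma
    return min(1.0, max(0.0, stretched))


if __name__ == "__main__":
    rng = random.Random(0)
    print([round(relaxed_bernoulli(1.0, 0.5, rng), 3) for _ in range(5)])
    print([round(hard_concrete(0.0, rng), 3) for _ in range(5)])
```

The stretch-and-clip step is what distinguishes Hard Concrete: a plain binary Concrete sample never equals 0 or 1 exactly, whereas the clipped sample does with positive probability, which is useful when the relaxed variable acts as a gate.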
🛠️ Install
Install using pip
```bash
pip install relaxit
```
Install from source
```bash
pip install git+https://github.com/intsystems/discrete-variables-relaxation
```
Install via Git clone
```bash
git clone https://github.com/intsystems/discrete-variables-relaxation
cd discrete-variables-relaxation
pip install -e .
```
🚀 Quickstart
```python
import torch
from relaxit.distributions import InvertibleGaussian

# Initialize the distribution parameters (requires_grad=True so we can differentiate w.r.t. them)
loc = torch.zeros(3, 4, 5, requires_grad=True)
scale = torch.ones(3, 4, 5, requires_grad=True)
temperature = torch.tensor([1.0])

# Initialize the distribution
distribution = InvertibleGaussian(loc, scale, temperature)

# Draw a sample with the reparameterization trick
sample = distribution.rsample()
print('sample.shape:', sample.shape)
print('sample.requires_grad:', sample.requires_grad)
```
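For intuition about what a reparameterized sample from an Invertible Gaussian looks like, the plain-Python sketch below follows the construction from the IGR paper as we understand it: draw y = μ + σ·ε with ε ~ N(0, 1), then map y/τ onto the simplex with a "softmax++" that adds a constant δ > 0 to the softmax denominator and appends one extra coordinate holding the leftover mass. This is an assumption-laden illustration, not relaxit's code: the function name and the default `delta = 1.0` are hypothetical, and the K−1 → K shape convention used here may not match the library's output shape.

```python
import math
import random


def invertible_gaussian_sample(loc: list[float], scale: list[float], temperature: float,
                               rng: random.Random, delta: float = 1.0) -> list[float]:
    """IGR-style sample (hypothetical helper): Gaussian reparameterization + tempered softmax++.

    loc and scale have length K - 1; the returned point on the simplex has length K.
    `delta` (> 0) is the softmax++ offset; its default here is a hypothetical choice.
    """
    # Reparameterized Gaussian draw: y = mu + sigma * eps with eps ~ N(0, 1).
    y = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(loc, scale)]
    x = [v / temperature for v in y]
    # softmax++: exp(x_k) / (sum_j exp(x_j) + delta), computed with a max shift for stability.
    log_delta = math.log(delta)
    shift = max(max(x), log_delta)
    exps = [math.exp(v - shift) for v in x]
    extra = math.exp(log_delta - shift)
    denom = sum(exps) + extra
    # The last coordinate collects the leftover mass delta / (sum_j exp(x_j) + delta).
    return [e / denom for e in exps] + [extra / denom]


if __name__ == "__main__":
    rng = random.Random(0)
    z = invertible_gaussian_sample([0.0, 0.5, -0.5, 1.0], [1.0] * 4, temperature=0.5, rng=rng)
    print([round(p, 3) for p in z], "sum =", round(sum(z), 6))
```

The noise enters only through ε, so in an autograd framework the sample is a differentiable function of μ and σ, which is what `rsample()` relies on.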
🎮 Demo
Demo previews: Laplace Bridge | REINFORCE in Acrobot environment | VAE with discrete latents
For demonstration purposes, we divide our algorithms into three groups, each with its own demo code:
- Laplace bridge between Dirichlet and LogisticNormal distributions
- REINFORCE
- Other relaxation methods
We describe our demo experiments here.
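The first two demo groups rest on short closed-form formulas. The plain-Python sketch below (not relaxit's code; all function names are hypothetical) shows (a) the Laplace bridge: the Laplace approximation of Dirichlet(α) in the softmax basis, with μ_k = log α_k − (1/K) Σ_l log α_l and Σ_kk = (1/α_k)(1 − 2/K) + (1/K²) Σ_l 1/α_l, together with the inverse map α_k = (1 − 2/K + e^{μ_k} Σ_l e^{−μ_l} / K²) / Σ_kk, which uses only the diagonal of Σ; and (b) a REINFORCE (score-function) gradient estimate for a single Bernoulli variable, which can be compared with the exact gradient σ(θ)(1 − σ(θ))(f(1) − f(0)).

```python
import math
import random


def dirichlet_to_logistic_normal(alpha: list[float]) -> tuple[list[float], list[float]]:
    """Laplace approximation of Dirichlet(alpha) in the softmax basis.

    Returns the logistic-normal mean and the diagonal of its covariance.
    """
    k = len(alpha)
    mean_log = sum(math.log(a) for a in alpha) / k
    mu = [math.log(a) - mean_log for a in alpha]
    inv_sum = sum(1.0 / a for a in alpha)
    var = [(1.0 / a) * (1.0 - 2.0 / k) + inv_sum / k**2 for a in alpha]
    return mu, var


def logistic_normal_to_dirichlet(mu: list[float], var: list[float]) -> list[float]:
    """Inverse map: recover Dirichlet concentrations from mu and diag(Sigma)."""
    k = len(mu)
    exp_neg_sum = sum(math.exp(-m) for m in mu)
    return [(1.0 - 2.0 / k + math.exp(m) * exp_neg_sum / k**2) / v for m, v in zip(mu, var)]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def reinforce_grad(theta: float, f, n_samples: int, rng: random.Random) -> float:
    """REINFORCE estimate of d/dtheta E[f(z)] with z ~ Bernoulli(sigmoid(theta)).

    Uses d/dtheta log p(z) = z - sigmoid(theta); f itself need not be differentiable.
    """
    p = sigmoid(theta)
    total = 0.0
    for _ in range(n_samples):
        z = 1 if rng.random() < p else 0
        total += f(z) * (z - p)
    return total / n_samples


if __name__ == "__main__":
    alpha = [0.5, 2.0, 3.0, 1.0]
    mu, var = dirichlet_to_logistic_normal(alpha)
    print("round trip:", [round(a, 6) for a in logistic_normal_to_dirichlet(mu, var)])

    def f(z: int) -> float:
        return (z - 0.4) ** 2

    theta = 0.3
    p = sigmoid(theta)
    exact = p * (1 - p) * (f(1) - f(0))
    estimate = reinforce_grad(theta, f, 100_000, random.Random(0))
    print(f"exact={exact:.4f}, reinforce={estimate:.4f}")
```

Mapping α to (μ, diag Σ) and back recovers α exactly, since the inverse is derived from the same diagonal formula. The REINFORCE estimator is unbiased but noisy, which is why it typically needs many samples or a variance-reduction baseline.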
📚 Stack
Some alternatives to Gumbel-Softmax (GS) are already implemented in Pyro, so we base our library on Pyro's codebase.
🧩 Some details
To keep the library consistent, we also expose the distributions from Pyro and PyTorch through it, so that all the categorical distributions can be imported from a single entry point.
👥 Contributors
- Daniil Dorin (Basic code writing, Final demo, Algorithms)
- Igor Ignashin (Project wrapping, Documentation writing, Algorithms)
- Nikita Kiselev (Project planning, Blog post, Algorithms)
- Andrey Veprikov (Tests writing, Documentation writing, Algorithms)
- You are welcome to contribute to our project!
🔗 Useful links
Download files
File details
Details for the file relaxit-1.0.0.tar.gz.
File metadata
- Download URL: relaxit-1.0.0.tar.gz
- Size: 10.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 9511ea7055332b1f644fcb0efe178fc772f0bbff4a75df9503b5fefb7965d76c |
| MD5 | f9ad2a4a057624586855ce6af3886a9a |
| BLAKE2b-256 | 538a49afc590f665439654a47df9bb0ded998e236cee70b60431f396403814ca |
Provenance
The following attestation bundles were made for relaxit-1.0.0.tar.gz:
- Publisher: pypi.yml on intsystems/discrete-variables-relaxation
- Statement:
  - Statement type: https://in-toto.io/Statement/v1
  - Predicate type: https://docs.pypi.org/attestations/publish/v1
  - Subject name: relaxit-1.0.0.tar.gz
  - Subject digest: 9511ea7055332b1f644fcb0efe178fc772f0bbff4a75df9503b5fefb7965d76c
- Sigstore transparency entry: 152974845
- Permalink: intsystems/discrete-variables-relaxation@1165c0553f4605af86d6e691fb2ba1d76839dd71
- Branch / Tag: refs/tags/v1.0.0
- Owner: https://github.com/intsystems
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: pypi.yml@1165c0553f4605af86d6e691fb2ba1d76839dd71
- Trigger Event: push
File details
Details for the file relaxit-1.0.0-py3-none-any.whl.
File metadata
- Download URL: relaxit-1.0.0-py3-none-any.whl
- Size: 16.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | d27bb0e46294b7a7a600cba909790bd445c692526f6c5d5a62bedbe28b3487d6 |
| MD5 | 934d5dae51a67c0ca43d4dedc4ce6452 |
| BLAKE2b-256 | 114f750b90830837d4430fdf706fa1b579197956c22ac373542990bd3c8c413f |
Provenance
The following attestation bundles were made for relaxit-1.0.0-py3-none-any.whl:
- Publisher: pypi.yml on intsystems/discrete-variables-relaxation
- Statement:
  - Statement type: https://in-toto.io/Statement/v1
  - Predicate type: https://docs.pypi.org/attestations/publish/v1
  - Subject name: relaxit-1.0.0-py3-none-any.whl
  - Subject digest: d27bb0e46294b7a7a600cba909790bd445c692526f6c5d5a62bedbe28b3487d6
- Sigstore transparency entry: 152974846
- Permalink: intsystems/discrete-variables-relaxation@1165c0553f4605af86d6e691fb2ba1d76839dd71
- Branch / Tag: refs/tags/v1.0.0
- Owner: https://github.com/intsystems
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: pypi.yml@1165c0553f4605af86d6e691fb2ba1d76839dd71
- Trigger Event: push