Relaxit: A Python Library for Optimizing Discrete Probability Distributions in Neural Networks

Project description

Just Relax It

Discrete Variables Relaxation

"Just Relax It" is a Python library for optimizing discrete probability distributions in neural networks. It provides a suite of relaxation techniques compatible with PyTorch.

📬 Assets

  1. Technical Meeting 1 - Presentation
  2. Technical Meeting 2 - Jupyter Notebook
  3. Technical Meeting 3 - Jupyter Notebook
  4. Blog Post 1
  5. Blog Post 2
  6. Documentation
  7. Tests
  8. Technical Report

💡 Motivation

Many mathematical problems require sampling discrete random variables. However, deep learning relies on gradient-based optimization, and gradients cannot be propagated through truly discrete samples, so using them directly is infeasible. Relaxation methods address this by replacing a discrete variable with a continuous approximation. The best known of these, the Concrete distribution or Gumbel-Softmax (the same distribution, proposed concurrently by two research groups), is implemented in several DL packages. In this project we implement alternatives to it.
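As a reference point, the Gumbel-Softmax relaxation that relaxit offers alternatives to can be sketched in a few lines of plain PyTorch. The helper `gumbel_softmax_sample` is illustrative only and is not part of relaxit; PyTorch itself exposes the same relaxation as `torch.nn.functional.gumbel_softmax` and `torch.distributions.RelaxedOneHotCategorical`. The sketch shows the core contrast: a discrete `Categorical` sample has no pathwise gradient, while the relaxed sample does.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical, RelaxedOneHotCategorical


def gumbel_softmax_sample(logits: torch.Tensor, tau: float) -> torch.Tensor:
    """Draw one relaxed sample from the Concrete / Gumbel-Softmax distribution (illustrative helper)."""
    # Gumbel(0, 1) noise via the inverse CDF: g = -log(-log(u)), u ~ Uniform[0, 1)
    u = torch.rand_like(logits).clamp_min(1e-10)
    g = -torch.log(-torch.log(u))
    # Tempered softmax of the perturbed logits; as tau -> 0 the sample approaches a one-hot vector
    return F.softmax((logits + g) / tau, dim=-1)


torch.manual_seed(0)
logits = torch.zeros(4, requires_grad=True)

# A truly discrete categorical sample offers no reparameterized (pathwise) gradient
print(Categorical(logits=logits).has_rsample)  # False

# The relaxed sample lies on the probability simplex and is differentiable in the logits
y = gumbel_softmax_sample(logits, tau=0.5)
loss = (y * torch.arange(4.0)).sum()  # any differentiable downstream objective
loss.backward()
print(y.sum().item(), logits.grad)

# PyTorch ships the same relaxation as a distribution with rsample()
relaxed = RelaxedOneHotCategorical(torch.tensor(0.5), logits=logits.detach())
print(relaxed.has_rsample, relaxed.rsample().shape)  # True torch.Size([4])
```

Lowering `tau` makes samples closer to one-hot vectors at the cost of higher-variance gradients, which is the trade-off the alternative relaxations in this library aim to handle differently.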

🗃 Algorithms

🛠️ Install Using uv (Recommended)

For Production

uv pip install relaxit

For Development

git clone https://github.com/intsystems/relaxit
cd relaxit
uv venv                    # create venv
source .venv/bin/activate  # activate venv
uv sync                    # install all the dependencies
uv pip install -e .        # make the relaxit package editable

To run tests:

uv run pytest tests/

To run Python scripts:

uv run python demo/vae_hard_concrete.py

To run notebooks:

uv run jupyter lab

⚒️ Install Using pip

For Production

pip install -r requirements.txt

For Development

pip install -r requirements-dev.txt

🚀 Quickstart

import torch
from relaxit.distributions import InvertibleGaussian

# initialize distribution parameters
loc = torch.zeros(3, 4, 5, requires_grad=True)
scale = torch.ones(3, 4, 5, requires_grad=True)
temperature = torch.tensor([1.0])

# initialize distribution
distribution = InvertibleGaussian(loc, scale, temperature)

# sample with reparameterization
sample = distribution.rsample()
print('sample.shape:', sample.shape)
print('sample.requires_grad:', sample.requires_grad)
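To illustrate what a distribution like `InvertibleGaussian` relaxes, here is a minimal from-scratch sketch of the invertible Gaussian reparameterization (Potapczynski et al., 2020) in plain PyTorch. It is not relaxit's implementation: the function name `invertible_gaussian_rsample`, the `delta = 1.0` default of the softmax++ map, and the extra output category are assumptions of this sketch. A Gaussian sample in R^(K-1) is divided by the temperature and mapped onto the K-category simplex, so gradients flow to `loc` and `scale` through the sample.

```python
import math

import torch


def invertible_gaussian_rsample(
    loc: torch.Tensor, scale: torch.Tensor, temperature: float, delta: float = 1.0
) -> torch.Tensor:
    """Reparameterized sample on the K-simplex from K-1 Gaussian coordinates (sketch)."""
    # Pathwise Gaussian sample y = loc + scale * eps, eps ~ N(0, I)
    y = loc + scale * torch.randn_like(loc)
    x = y / temperature
    # softmax++: z_k = exp(x_k) / (sum_j exp(x_j) + delta) for the first K-1 entries and
    # z_K = 1 - sum_k z_k. Appending log(delta) as an extra logit and applying a plain
    # softmax yields exactly these K values in a numerically stable way.
    extra = torch.full_like(x[..., :1], math.log(delta))
    return torch.softmax(torch.cat([x, extra], dim=-1), dim=-1)


torch.manual_seed(0)
loc = torch.zeros(3, 4, 5, requires_grad=True)
scale = torch.ones(3, 4, 5, requires_grad=True)

z = invertible_gaussian_rsample(loc, scale, temperature=1.0)
print(z.shape)  # torch.Size([3, 4, 6]): 5 Gaussian coordinates map to a 6-category simplex

# Gradients reach both distribution parameters through the sample
z[..., 0].sum().backward()
print(loc.grad.shape, scale.grad.shape)
```

Whether relaxit's `InvertibleGaussian` returns the same output shape or uses the same `delta` convention is not stated here; check the library documentation before relying on either.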

🎮 Demo

Three demos are provided, each with an accompanying Colab notebook:

  1. Laplace Bridge
  2. REINFORCE in Acrobot environment
  3. VAE with discrete latents

For demonstration purposes, we divide our algorithms into three[^*] different groups. Each group relates to a particular demo:

We describe our demo experiments here.

[^*]: We also implement the REINFORCE algorithm, a score function estimator, as an alternative to our relaxation methods, which are inherently pathwise derivative estimators. REINFORCE is implemented only for the demo experiments and is not included in the package source code.
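For contrast with the pathwise estimators in the package, the score-function (REINFORCE) idea can be sketched on a toy problem: estimate the gradient of E[f(z)] for z ~ Bernoulli(theta) using only samples of z and the gradient of log p(z). The helper `reinforce_grad` and the objective `f` are hypothetical and unrelated to the demo code; the toy objective is chosen so the exact gradient is known in closed form.

```python
import torch
from torch.distributions import Bernoulli


def reinforce_grad(theta: torch.Tensor, f, n_samples: int = 100_000) -> torch.Tensor:
    """Score-function (REINFORCE) estimate of d/dtheta E_{z ~ Bernoulli(theta)}[f(z)]."""
    dist = Bernoulli(probs=theta)
    z = dist.sample((n_samples,))  # discrete samples; no gradient flows through z itself
    # Surrogate whose gradient is mean(f(z) * d/dtheta log p(z)); f(z) is treated as a constant
    surrogate = (f(z).detach() * dist.log_prob(z)).mean()
    (grad,) = torch.autograd.grad(surrogate, theta)
    return grad


def f(z: torch.Tensor) -> torch.Tensor:
    return (z - 0.45) ** 2  # hypothetical toy objective


torch.manual_seed(0)
theta = torch.tensor(0.3, requires_grad=True)

# Exact gradient for comparison: E[f] = theta * f(1) + (1 - theta) * f(0),
# so dE/dtheta = f(1) - f(0) = 0.55**2 - 0.45**2 = 0.1
exact = f(torch.tensor(1.0)) - f(torch.tensor(0.0))
estimate = reinforce_grad(theta, f)
print(f"exact={exact.item():.4f}  REINFORCE={estimate.item():.4f}")
```

The estimator needs no continuous relaxation of z, but it typically has much higher variance than pathwise gradients, which is why it serves here only as a baseline.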

📚 Stack

Some alternatives to Gumbel-Softmax (GS) were already implemented in Pyro, so we base our library on its codebase.

🧩 Some details

To keep the library consistent, we re-export distributions from Pyro and PyTorch, so that all categorical distributions can be imported from a single entry point.

👥 Contributors

🔗 Useful links

Download files

Source Distribution

relaxit-1.2.1.tar.gz (25.9 kB)

Uploaded Source

Built Distribution

relaxit-1.2.1-py3-none-any.whl (36.3 kB)

Uploaded Python 3

File details

Details for the file relaxit-1.2.1.tar.gz.

File metadata

  • Download URL: relaxit-1.2.1.tar.gz
  • Upload date:
  • Size: 25.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for relaxit-1.2.1.tar.gz
  • SHA256: 05a53e2fe69b5d576ebe96e89856ed4d37e9c57543c9af6d19137e422ed60aca
  • MD5: d13f546edc9cfd312606008fbfb012cd
  • BLAKE2b-256: 5b075357b449ea8ef54d4ac44db282098477b9954d1f1f795ea3811d19ffada3

Provenance

The following attestation bundles were made for relaxit-1.2.1.tar.gz:

Publisher: pypi.yml on intsystems/relaxit

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file relaxit-1.2.1-py3-none-any.whl.

File metadata

  • Download URL: relaxit-1.2.1-py3-none-any.whl
  • Upload date:
  • Size: 36.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for relaxit-1.2.1-py3-none-any.whl
  • SHA256: 91e2bf1522abb8060bbe6f41ac3ffe198d4c54ed52d23e0d5532b7c087c1430d
  • MD5: f727390820263d7512891d1171793aa0
  • BLAKE2b-256: 393e6505b4440e04aaa15f0e886212eb32622a0a9f121f1b84bf3c38c023de1c

Provenance

The following attestation bundles were made for relaxit-1.2.1-py3-none-any.whl:

Publisher: pypi.yml on intsystems/relaxit

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
