
SESaMo provides an extension to normalizing flows that enforces symmetries in the output distribution.


Paper: arXiv:2505.19619 | License: MIT

SESaMo: Symmetry-Enforcing Stochastic Modulation for Normalizing Flows
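The general idea of enforcing a discrete symmetry through stochastic modulation can be illustrated in plain PyTorch for the Z2 case (x → −x): draw a sample from a base density q that is not symmetric, flip its sign with probability 1/2, and the result follows the symmetrized density p(x) = ½[q(x) + q(−x)]. The sketch below is illustrative only: a shifted Gaussian stands in for the flow, it is not the package's `Z2Modulation` implementation, and the function names are hypothetical.

```python
import math

import torch
from torch.distributions import Normal


def z2_modulated_sample(base, n):
    """Draw n samples from `base` and flip the sign of each with probability 1/2."""
    z = base.sample((n,))                                     # shape (n, d)
    signs = torch.randint(0, 2, (n, 1)).to(z.dtype) * 2 - 1  # random +1/-1 per sample
    return signs * z


def z2_symmetrized_log_prob(base, x):
    """log p(x) for p(x) = (q(x) + q(-x)) / 2, where q is the base density."""
    log_q_pos = base.log_prob(x).sum(-1)   # log q(x), summed over the d components
    log_q_neg = base.log_prob(-x).sum(-1)  # log q(-x)
    # log((q(x) + q(-x)) / 2), computed stably in log space
    return torch.logsumexp(torch.stack([log_q_pos, log_q_neg]), dim=0) - math.log(2.0)


# A deliberately non-symmetric stand-in for the flow: a shifted 2D Gaussian.
base = Normal(torch.tensor([1.5, -0.5]), torch.tensor([1.0, 1.0]))
x = z2_modulated_sample(base, 1000)
log_p = z2_symmetrized_log_prob(base, x)

# The modulated density is invariant under x -> -x even though q is not.
print(torch.allclose(log_p, z2_symmetrized_log_prob(base, -x)))  # True
```

The sign flip costs nothing at sampling time, and the exact density of the modulated samples stays tractable because the symmetry group here has only two elements.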

Quick start

Install the package with pip:

pip install sesamo

Here is a quick example of how to use SESaMo to build a normalizing flow with stochastic modulation:

import torch
from sesamo import Sesamo
from sesamo.models import GaussianPrior, RealNVP, Z2Modulation, Z2Regularization
from sesamo.loss import StochmodLoss

# Initialize SESaMo
sesamo = Sesamo(
    prior=GaussianPrior(
        var=1,
        lat_shape=[1,2]
    ),
    flow=RealNVP(
        lat_shape=[1,2],
        num_coupling_layers=10,
        num_hidden_layers=2,
        num_hidden_features=40
    ),
    stochastic_modulation=Z2Modulation(),
    regularization=Z2Regularization(),
).to("cuda")

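# Action S(x) of the target distribution p(x) = exp(-S(x)) / Z, returning one value per sample.
# Placeholder for illustration only (not part of the package): a Z2-symmetric double-well potential.
def action(x):
    return ((x**2 - 1.0) ** 2).flatten(start_dim=1).sum(dim=1)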
loss_fn = StochmodLoss()
optimizer = torch.optim.Adam(sesamo.parameters(), lr=5e-4)

# Training loop
for _ in range(10_000):
    # reset gradients
    optimizer.zero_grad()

    # sample from sesamo
    samples, log_prob, log_prob_stochmod, penalty = sesamo.sample_for_training(8_000)
    
    # compute action and loss
    action_samples = action(samples)
    loss = loss_fn(action_samples, log_prob, log_prob_stochmod, penalty).mean()
    
    # backpropagate and update flow parameters
    loss.backward()
    optimizer.step()
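The quick start does not spell out what `StochmodLoss` computes. For orientation, the textbook reverse Kullback-Leibler objective for training a flow q against an unnormalized target p(x) ∝ exp(−S(x)) is sketched below. It is not necessarily the exact form used by `StochmodLoss`, which additionally receives `log_prob_stochmod` and `penalty`; the function name `reverse_kl_loss` is hypothetical.

```python
import math

import torch


def reverse_kl_loss(action_samples, log_prob):
    """Per-sample reverse-KL loss log q(x) + S(x) for samples x ~ q.

    Its expectation equals KL(q || p) - log Z, so minimizing it minimizes
    KL(q || p) without ever needing the normalization Z of p(x) = exp(-S(x)) / Z.
    """
    return log_prob + action_samples


# Sanity check with a "perfect flow": target S(x) = x^2 / 2 and q = N(0, 1) exactly.
x = torch.randn(5)
action_samples = 0.5 * x**2
log_prob = -0.5 * x**2 - 0.5 * math.log(2 * math.pi)
loss = reverse_kl_loss(action_samples, log_prob)

# At the optimum every sample gives the same value, -log Z = -0.5 * log(2 * pi).
print(loss)
```

Because KL(q || p) ≥ 0, `-loss.mean()` is a stochastic lower bound on log Z, and a loss whose per-sample values barely vary is a sign that the flow has matched the target.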

Examples

For more examples, see the SESaMo/examples folder, which contains Jupyter notebooks for the Hubbard model and the Gaussian mixture model.
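To try the quick start on a target in the spirit of the Gaussian mixture example, the action can be taken as the negative log of an unnormalized mixture density. The sketch below uses a hypothetical two-component mixture with means at ±mu, which is symmetric under x → −x; the parameters are illustrative and not those used in the notebook. It assumes samples of shape (batch, 1, 2), matching `lat_shape=[1, 2]`, and flattens all non-batch dimensions.

```python
import torch


def make_gaussian_mixture_action(mu, sigma=0.5):
    """Return S(x) = -log(N(x; mu, sigma^2 I) + N(x; -mu, sigma^2 I)) up to an additive constant.

    The two components are mirror images of each other, so S(x) = S(-x) (a Z2 symmetry).
    """
    mu = torch.as_tensor(mu, dtype=torch.float32)

    def action(x):
        x = x.flatten(start_dim=1)  # (batch, d)
        m = mu.to(x.device)
        sq_pos = ((x - m) ** 2).sum(dim=1) / (2 * sigma**2)
        sq_neg = ((x + m) ** 2).sum(dim=1) / (2 * sigma**2)
        # -log(exp(-a) + exp(-b)), computed stably with logsumexp
        return -torch.logsumexp(torch.stack([-sq_pos, -sq_neg]), dim=0)

    return action


action = make_gaussian_mixture_action(mu=[2.0, 2.0])
x = torch.randn(8, 1, 2)
print(action(x).shape)  # torch.Size([8])
```

Dropping the Gaussian normalization factors is harmless here, because any additive constant in S(x) is absorbed into Z.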

Run experiments

To run the experiments from the paper, follow the instructions below.

Clone the repository and move into the directory:

git clone https://github.com/janikkreit/SESaMo.git
cd SESaMo

Create a Python virtual environment and install the package in editable mode:

python -m venv .venv
source .venv/bin/activate
pip install -e .

Run an experiment with:

cd experiments
python train.py -cp configs/<experiment> -cn <model>

Available <experiment>s are:

hubbard2x1
hubbard18x100
gaussian-mixture
broken-gaussian-mixture
complex-phi4
broken-complex-phi4
broken-scalar-phi4

Available <model>s are:

realnvp
vmonf
canonicalization
sesamo
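The flags `-cp` and `-cn` are Hydra's short forms of `--config-path` and `--config-name`, which suggests that `train.py` is a Hydra application. Assuming that, and assuming the script is started from the experiments directory, every model for one experiment can be launched in sequence with a small driver like the one below. The driver and its function names are hypothetical helpers, not part of the repository.

```python
import subprocess
import sys

EXPERIMENTS = [
    "hubbard2x1", "hubbard18x100", "gaussian-mixture", "broken-gaussian-mixture",
    "complex-phi4", "broken-complex-phi4", "broken-scalar-phi4",
]
MODELS = ["realnvp", "vmonf", "canonicalization", "sesamo"]


def build_command(experiment, model):
    """Command for one run: python train.py -cp configs/<experiment> -cn <model>."""
    if experiment not in EXPERIMENTS:
        raise ValueError(f"unknown experiment: {experiment}")
    if model not in MODELS:
        raise ValueError(f"unknown model: {model}")
    return [sys.executable, "train.py", "-cp", f"configs/{experiment}", "-cn", model]


def run_all_models(experiment, dry_run=True):
    """Print every model's command for one experiment and, unless dry_run, execute it."""
    commands = [build_command(experiment, model) for model in MODELS]
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failed run
    return commands


if __name__ == "__main__":
    run_all_models("gaussian-mixture", dry_run=True)
```

Running the models one after another keeps GPU memory usage predictable; a dry run first makes it easy to check the generated commands.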

The checkpoint, TensorBoard, config, and stats files are stored in the SESaMo/scripts/runs folder. After training completes or is interrupted, the distribution is plotted and saved as SESaMo/scripts/runs/.../samples.png.
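Since each run writes its plot into its own subdirectory of the runs folder, a small helper can locate the most recent one. This sketch assumes only that plots are named samples.png somewhere below the runs folder (the intermediate path is elided above); the function name is hypothetical, and the example call assumes it is run from the repository root.

```python
from pathlib import Path


def latest_samples_plot(runs_dir):
    """Return the most recently modified samples.png below runs_dir, or None if there is none."""
    runs = Path(runs_dir)
    if not runs.is_dir():
        return None
    plots = list(runs.rglob("samples.png"))  # search all run subdirectories
    if not plots:
        return None
    return max(plots, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    print(latest_samples_plot("scripts/runs"))
```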

Citation

If you use SESaMo in your research, please consider citing our paper:

@article{kreit2025sesamo,
    title={SESaMo: Symmetry-Enforcing Stochastic Modulation for Normalizing Flows}, 
    author={Janik Kreit and Dominic Schuh and Kim A. Nicoli and Lena Funcke},
    year={2025},
    eprint={2505.19619},
    archivePrefix={arXiv},
    primaryClass={cs.LG},
    url={https://arxiv.org/abs/2505.19619}, 
}

Download files

Download the file for your platform.

Source Distribution

sesamo-1.0.0.tar.gz (23.6 kB)

Uploaded Source

Built Distribution


sesamo-1.0.0-py3-none-any.whl (27.2 kB)

Uploaded Python 3

File details

Details for the file sesamo-1.0.0.tar.gz.

File metadata

  • Download URL: sesamo-1.0.0.tar.gz
  • Upload date:
  • Size: 23.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.21

File hashes

Hashes for sesamo-1.0.0.tar.gz
Algorithm Hash digest
SHA256 47f6a7884e5a13162bc8adad92f37b3d640387c5b1a25d18ecb651b4501c3459
MD5 bd8874acd8b3c3ce74a25677938e2f81
BLAKE2b-256 c18c8b216b04d38a22b14a79c6293efab4e436fdf0f6ece2ffdb2b45e8a20108
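The digests above can be used to check a downloaded archive before installing it. A minimal sketch with Python's standard `hashlib` module; the helper names are hypothetical, and the expected value is the SHA256 digest listed for sesamo-1.0.0.tar.gz.

```python
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "47f6a7884e5a13162bc8adad92f37b3d640387c5b1a25d18ecb651b4501c3459"


def sha256_of(path, chunk_size=1 << 16):
    """Hex SHA256 digest of a file, read in chunks so large files need little memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 digest matches the expected hex string (case-insensitive)."""
    return sha256_of(path) == expected.lower()


if __name__ == "__main__":
    archive = Path("sesamo-1.0.0.tar.gz")
    if archive.exists():
        print("OK" if verify(archive) else "MISMATCH")
```

pip can enforce the same check at install time through its hash-checking mode, using a requirements line of the form `sesamo==1.0.0 --hash=sha256:<digest>`.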


File details

Details for the file sesamo-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: sesamo-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 27.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.21

File hashes

Hashes for sesamo-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 8b9d850cb893e26bb67791a9bca29d2a06592472a126c95aa9a71a294cd5ffea
MD5 32650ed4e1ae74ed2ced5be9b67fcf34
BLAKE2b-256 414223ba71faa56784e596a76037978e1ac3d91b079a6f0e6ba2b39084011a48

