SESaMo provides an extension to Normalizing Flows that enforces symmetries on the output distribution.
SESaMo: Symmetry-Enforcing Stochastic Modulation for Normalizing Flows
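As a rough illustration of what enforcing a symmetry on the output distribution can mean, the sketch below applies plain group averaging for the Z2 symmetry x → −x. It draws from an asymmetric base density and then flips the sign with probability 1/2, which turns q(x) into (q(x) + q(−x))/2. This is a generic construction in plain Python with a hypothetical Gaussian base density. The names `base_log_prob`, `symmetrized_log_prob` and `symmetrized_sample` are illustrative. The code is not taken from the package and is not SESaMo's `Z2Modulation` implementation.

```python
import math
import random


def base_log_prob(x, mu=1.5, sigma=0.5):
    """Log-density of a hypothetical, deliberately asymmetric base N(mu, sigma^2)."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def logsumexp2(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def symmetrized_log_prob(x, log_prob=base_log_prob):
    """log q_sym(x) with q_sym(x) = (q(x) + q(-x)) / 2, invariant under x -> -x."""
    return logsumexp2(log_prob(x), log_prob(-x)) - math.log(2.0)


def symmetrized_sample(rng, mu=1.5, sigma=0.5):
    """Sample from q_sym: draw x ~ q, then flip its sign with probability 1/2."""
    x = rng.gauss(mu, sigma)
    sign = 1.0 if rng.random() < 0.5 else -1.0
    return sign * x


rng = random.Random(0)
samples = [symmetrized_sample(rng) for _ in range(20_000)]
print(sum(samples) / len(samples))  # close to 0 although the base mean is 1.5
print(symmetrized_log_prob(0.7), symmetrized_log_prob(-0.7))  # identical values
```

Each sample keeps its magnitude and only its sign is randomized. The density of the sampled points is therefore exactly the group average, so for a two-element group its log-probability costs just two base evaluations.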
Quick start
Install the package with pip:
```bash
pip install sesamo
```
Here is a quick example of how to use SESaMo to build a normalizing flow with stochastic modulation:
```python
import torch
from sesamo import Sesamo
from sesamo.models import GaussianPrior, RealNVP, Z2Modulation, Z2Regularization
from sesamo.loss import StochmodLoss

# Use the GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize SESaMo
sesamo = Sesamo(
    prior=GaussianPrior(
        var=1,
        lat_shape=[1, 2]
    ),
    flow=RealNVP(
        lat_shape=[1, 2],
        num_coupling_layers=10,
        num_hidden_layers=2,
        num_hidden_features=40
    ),
    stochastic_modulation=Z2Modulation(),
    regularization=Z2Regularization(),
).to(device)

# Define the action S(x) of the target distribution p(x) = exp(-S(x)) / Z.
# Hypothetical placeholder, replace with your own: a Z2-symmetric double well
# S(x) = sum_i (x_i^2 - 1)^2, returning one value per sample.
def action(x):
    return ((x ** 2 - 1) ** 2).flatten(start_dim=1).sum(dim=1)

loss_fn = StochmodLoss()
optimizer = torch.optim.Adam(sesamo.parameters(), lr=5e-4)

# Training loop
for _ in range(10_000):
    # reset gradients
    optimizer.zero_grad()
    # sample from sesamo
    samples, log_prob, log_prob_stochmod, penalty = sesamo.sample_for_training(8_000)
    # compute action and loss
    action_samples = action(samples)
    loss = loss_fn(action_samples, log_prob, log_prob_stochmod, penalty).mean()
    # backpropagate and update flow parameters
    loss.backward()
    optimizer.step()
```
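The training loss in the quick-start example is computed from the action of the samples and their log-probabilities. Flows trained against a known action are commonly optimized with the reverse Kullback–Leibler objective E_q[log q(x) + S(x)] = KL(q‖p) − log Z. This objective can be estimated from model samples alone, and the unknown constant log Z does not affect the optimum. We have not checked whether `StochmodLoss` uses exactly this form; it also takes `log_prob_stochmod` and `penalty`. The plain-Python sketch below therefore only illustrates the standard objective. It uses a hypothetical 1D Gaussian model q = N(mu, sigma²) and a hypothetical target action S(x) = x²/2.

```python
import math
import random


def gaussian_log_prob(x, mu, sigma):
    """Log-density of the hypothetical model q = N(mu, sigma^2)."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def action(x):
    """Hypothetical target action S(x) = x^2 / 2, so p(x) = exp(-S(x)) / Z with Z = sqrt(2*pi)."""
    return 0.5 * x * x


def reverse_kl_loss(mu, sigma, n_samples=10_000, seed=0):
    """Monte Carlo estimate of E_q[log q(x) + S(x)] = KL(q || p) - log Z."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        x = rng.gauss(mu, sigma)  # draw a sample from the model q
        total += gaussian_log_prob(x, mu, sigma) + action(x)
    return total / n_samples


log_z = 0.5 * math.log(2 * math.pi)
print(reverse_kl_loss(0.0, 1.0) + log_z)  # ~0: the model matches the target
print(reverse_kl_loss(2.0, 1.0) + log_z)  # ~2: KL(N(2, 1) || N(0, 1)) = 2
```

When q equals p, every sample contributes the same value −log Z, so the estimate has zero variance at the optimum. Any mismatch raises the expected loss by exactly KL(q‖p).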
Examples
For more examples, see the SESaMo/examples folder, which contains Jupyter notebooks for the Hubbard model and the Gaussian mixture model.
Run experiments
To run the experiments from the paper, follow the instructions below.
Clone the repository and move into the directory:
```bash
git clone https://github.com/janikkreit/SESaMo.git
cd SESaMo
```
Create a Python virtual environment and install the package:
```bash
python -m venv .venv
source .venv/bin/activate
pip install -e .
```
Run an experiment with:
```bash
cd experiments
python train.py -cp configs/<experiment> -cn <model>
```
Available values for `<experiment>`:
- hubbard2x1
- hubbard18x100
- gaussian-mixture
- broken-gaussian-mixture
- complex-phi4
- broken-complex-phi4
- broken-scalar-phi4
Available values for `<model>`:
- realnvp
- vmonf
- canonicalization
- sesamo
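To sweep several models over one experiment, you can generate the `train.py` invocation for each combination. The helper below is a hypothetical convenience script and is not part of the repository. It only assembles the documented `-cp configs/<experiment> -cn <model>` arguments. When `dry_run=False`, it runs each command with `subprocess` inside the `experiments` directory, assuming it is started from the repository root.

```python
import itertools
import subprocess

# Values copied from the lists above
EXPERIMENTS = [
    "hubbard2x1", "hubbard18x100", "gaussian-mixture", "broken-gaussian-mixture",
    "complex-phi4", "broken-complex-phi4", "broken-scalar-phi4",
]
MODELS = ["realnvp", "vmonf", "canonicalization", "sesamo"]


def build_command(experiment, model):
    """Argument list for `python train.py -cp configs/<experiment> -cn <model>`."""
    return ["python", "train.py", "-cp", f"configs/{experiment}", "-cn", model]


def run_sweep(experiments, models, cwd="experiments", dry_run=True):
    """Build (and, if dry_run is False, run) one command per (experiment, model) pair."""
    commands = [build_command(e, m) for e, m in itertools.product(experiments, models)]
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            # check=True stops the sweep at the first failing run
            subprocess.run(cmd, cwd=cwd, check=True)
    return commands


if __name__ == "__main__":
    # Print the commands for every model on one experiment without running them
    run_sweep(["gaussian-mixture"], MODELS)
```

The script keeps `dry_run=True` as the default, so you can inspect the command list before starting any potentially long training runs.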
Checkpoints, TensorBoard logs, configs, and stats files are stored in the SESaMo/scripts/runs folder.
After training completes or is interrupted, the distribution is plotted and saved as SESaMo/scripts/runs/.../samples.png.
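The subdirectory layout below `runs` is not spelled out here. A simple way to find the newest plot is therefore to search the whole runs folder recursively. `find_latest_samples_plot` is a hypothetical helper that uses only `pathlib`. Pass it the path of the runs folder.

```python
from pathlib import Path


def find_latest_samples_plot(runs_dir):
    """Return the most recently modified samples.png anywhere below runs_dir, or None."""
    plots = list(Path(runs_dir).rglob("samples.png"))
    if not plots:
        return None
    # Pick the file with the newest modification time
    return max(plots, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    print(find_latest_samples_plot("SESaMo/scripts/runs"))
```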
Citation
If you use SESaMo in your research, please consider citing our paper:
```bibtex
@article{kreit2025sesamo,
  title={SESaMo: Symmetry-Enforcing Stochastic Modulation for Normalizing Flows},
  author={Janik Kreit and Dominic Schuh and Kim A. Nicoli and Lena Funcke},
  year={2025},
  eprint={2505.19619},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2505.19619},
}
```
Download files
File details
Details for the file sesamo-1.0.0.tar.gz.
File metadata
- Download URL: sesamo-1.0.0.tar.gz
- Upload date:
- Size: 23.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.21
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 47f6a7884e5a13162bc8adad92f37b3d640387c5b1a25d18ecb651b4501c3459 |
| MD5 | bd8874acd8b3c3ce74a25677938e2f81 |
| BLAKE2b-256 | c18c8b216b04d38a22b14a79c6293efab4e436fdf0f6ece2ffdb2b45e8a20108 |
File details
Details for the file sesamo-1.0.0-py3-none-any.whl.
File metadata
- Download URL: sesamo-1.0.0-py3-none-any.whl
- Upload date:
- Size: 27.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.21
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8b9d850cb893e26bb67791a9bca29d2a06592472a126c95aa9a71a294cd5ffea |
| MD5 | 32650ed4e1ae74ed2ced5be9b67fcf34 |
| BLAKE2b-256 | 414223ba71faa56784e596a76037978e1ac3d91b079a6f0e6ba2b39084011a48 |
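The standard-library `hashlib` module is enough to check a downloaded file against the digests listed for `sesamo-1.0.0.tar.gz` and `sesamo-1.0.0-py3-none-any.whl`. The sketch below reads the file in chunks. The expected SHA256 values are copied from the file details, while the helper names `file_digest` and `verify_download` are illustrative. BLAKE2b-256 corresponds to `hashlib.blake2b(digest_size=32)`.

```python
import hashlib

# SHA256 digests copied from the file details above
EXPECTED_SHA256 = {
    "sesamo-1.0.0.tar.gz": "47f6a7884e5a13162bc8adad92f37b3d640387c5b1a25d18ecb651b4501c3459",
    "sesamo-1.0.0-py3-none-any.whl": "8b9d850cb893e26bb67791a9bca29d2a06592472a126c95aa9a71a294cd5ffea",
}


def file_digest(path, algorithm="sha256", chunk_size=1 << 20):
    """Hex digest of a file, read in chunks so large files need little memory."""
    if algorithm == "blake2b-256":
        h = hashlib.blake2b(digest_size=32)
    else:
        h = hashlib.new(algorithm)  # e.g. "sha256" or "md5"
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_download(path, expected_hex, algorithm="sha256"):
    """True if the file's digest matches the expected hex string (case-insensitive)."""
    return file_digest(path, algorithm) == expected_hex.strip().lower()
```

For example, `verify_download("sesamo-1.0.0.tar.gz", EXPECTED_SHA256["sesamo-1.0.0.tar.gz"])` returns `True` only if the local archive is byte-for-byte identical to the published one.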