Bijections & normalizing flows with JAX/NNX
bijection + jax = /ˈbaɪdʒæks/
This library provides flexible tools for building normalizing flows and bijections with tractable changes of density, with a focus on research and applications in physics. It aims to provide reusable building blocks rather than a simplified interface for common use cases.
The library is built around two fundamental mathematical objects:
- Bijections: Invertible transformations that track their effect on probability densities.
- Distributions: Probability distributions with methods for sampling and density evaluation.
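To make the two objects concrete, here is a minimal pure-JAX sketch of the idea. It is not bijx's API, and the names `normal_sample`, `affine_forward` and `affine_reverse` are hypothetical. A distribution returns samples together with their log density. A bijection maps `(x, log_density)` to `(y, log_density)` and subtracts the log-absolute-determinant of its Jacobian (the change-of-variables formula).

```python
import jax
import jax.numpy as jnp


def normal_sample(rng, batch_shape, event_shape):
    """Hypothetical distribution: standard normal samples plus their log density."""
    x = jax.random.normal(rng, batch_shape + event_shape)
    event_axes = tuple(range(-len(event_shape), 0))
    log_p = -0.5 * jnp.sum(x**2 + jnp.log(2 * jnp.pi), axis=event_axes)
    return x, log_p


def affine_forward(x, log_p, shift, log_scale):
    """Hypothetical elementwise bijection y = x * exp(log_scale) + shift."""
    y = x * jnp.exp(log_scale) + shift
    # change of variables: log p_y(y) = log p_x(x) - log|det J|
    return y, log_p - jnp.sum(log_scale)


def affine_reverse(y, log_p, shift, log_scale):
    """Inverse map; adds the log-determinant back."""
    x = (y - shift) * jnp.exp(-log_scale)
    return x, log_p + jnp.sum(log_scale)


shift = jnp.array([5.0, -2.0])
log_scale = jnp.log(jnp.array([2.0, 3.0]))

x, lp_x = normal_sample(jax.random.PRNGKey(0), (5,), (2,))
y, lp_y = affine_forward(x, lp_x, shift, log_scale)
x_rec, lp_rec = affine_reverse(y, lp_y, shift, log_scale)
```

Because the density is carried along with the samples, `lp_y` is exactly the log density of `y` under the transformed distribution. No separate density formula for the pushed-forward distribution is needed.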
Key Features
Core design goals:
- Modular Building Blocks: Composable components rather than monolithic frameworks
- Research-Focused: Prioritizes flexibility and expressiveness over convenience
- Flax NNX Integration: Modern state management for neural network components
Physics & advanced methods:
- Continuous Normalizing Flows: CNFs as core primitive with flexible ODE solver backends
- Matrix Lie Group Operations: Automatic differentiation on matrix groups
- Structure-Preserving Integration: Crouch-Grossmann solvers for ODEs on matrix groups
- Symmetry-Aware Architectures: Equivariant layers and transformations (e.g., lattice-symmetric CNFs)
- Fourier-Space Operations: Tools for momentum space transformations and complex field decomposition
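Structure-preserving integration can be illustrated with the simplest member of the Crouch-Grossmann family, the first-order Lie-Euler step. It advances Y' = A(t, Y) Y, where A takes values in the Lie algebra, via Y_{n+1} = exp(h A(t_n, Y_n)) Y_n.

The exponential of a skew-symmetric matrix is a rotation, so the iterates remain on SO(3) up to rounding, whereas naive Euler drifts off the group. This is a plain-JAX sketch, not the integrators in `cg.py`. The vector field `generator` and the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def hat(w):
    """Map a 3-vector to the corresponding skew-symmetric matrix in so(3)."""
    return jnp.array([
        [0.0, -w[2], w[1]],
        [w[2], 0.0, -w[0]],
        [-w[1], w[0], 0.0],
    ])


def generator(t, y):
    """Hypothetical Lie-algebra-valued vector field A(t, Y)."""
    w = jnp.array([jnp.cos(t), jnp.sin(t), 0.5]) + 0.1 * jnp.diag(y)
    return hat(w)


def lie_euler(y0, t0, t1, n_steps):
    """First-order Crouch-Grossmann (Lie-Euler) integration of Y' = A(t, Y) Y."""
    h = (t1 - t0) / n_steps

    def step(y, i):
        t = t0 + i * h
        # multiply by a group element, so Y stays on the group
        return expm(h * generator(t, y)) @ y, None

    y1, _ = jax.lax.scan(step, y0, jnp.arange(n_steps))
    return y1


def explicit_euler(y0, t0, t1, n_steps):
    """Naive Euler in the ambient matrix space, for comparison."""
    h = (t1 - t0) / n_steps

    def step(y, i):
        t = t0 + i * h
        return y + h * generator(t, y) @ y, None

    y1, _ = jax.lax.scan(step, y0, jnp.arange(n_steps))
    return y1


def orthogonality_error(y):
    """Max deviation of Y^T Y from the identity (zero on SO(3))."""
    return jnp.max(jnp.abs(y.T @ y - jnp.eye(3)))


y0 = jnp.eye(3)
y_cg = lie_euler(y0, 0.0, 2.0, 50)
y_eu = explicit_euler(y0, 0.0, 2.0, 50)
```

With the same step size, `orthogonality_error(y_cg)` stays at rounding level, while the Euler result visibly leaves the group. Higher-order Crouch-Grossmann schemes compose several such exponentials per step.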
Related Publications
This library extends research software developed for the following publications, which you may find useful and may consider citing if you use the relevant components:
Quickstart
Here is a minimal example of building and sampling from a simple normalizing flow. We transform samples from a base distribution (a standard normal) using a chain of bijections to produce samples from a new, transformed distribution.
```python
import jax
import jax.numpy as jnp
import bijx
from flax import nnx

# Define a simple base distribution
prior = bijx.IndependentNormal(event_shape=(2,))

# Choose a base bijection for the coupling layer
base_bijection = bijx.ModuleReconstructor(
    # for a Real NVP-style flow, use an affine-linear base bijection
    bijx.AffineLinear(rngs=nnx.Rngs(0))
)

# Build a simple coupling layer
coupling_layer = bijx.GeneralCouplingLayer(
    # placeholder: replace with a more realistic NN (active -> params)
    embedding_net=lambda x: jnp.zeros(x.shape + (base_bijection.params_total_size,)),
    mask=bijx.BinaryMask.from_boolean_mask(jnp.array([True, False])),
    bijection_reconstructor=base_bijection,
)

# Build a very simple continuous flow
# note: automatic divergence computation is expensive in high dimensions
cnf = bijx.ContFlowDiffrax(
    # the vector field must also return its divergence; AutoJacVF computes it for us
    vf=bijx.AutoJacVF(
        lambda t, x: -x,
    )
)

# Compose bijections
flow = bijx.Chain(
    bijx.Shift(jnp.array([5.0, -2.0])),
    bijx.Scaling(jnp.array([2.0, 0.5])),
    # add more coupling layers for expressivity
    coupling_layer,
    # final continuous flow
    cnf,
)

# Sample from the base distribution
rng = jax.random.PRNGKey(42)
x, lp_x = prior.sample(batch_shape=(5,), rng=rng)

# Transform the samples with the bijection
y, lp_y = flow.forward(x, lp_x)

# The same samples can be generated by combining prior and flow into a new distribution
dist = bijx.Transformed(prior, flow)
y2, lp_y2 = dist.sample(batch_shape=(5,), rng=rng)
# identical, since the same random key was used
assert jnp.allclose(y, y2)

# Check bijectivity
x_rec, lp_x_rec = flow.reverse(y, lp_y)
assert jnp.allclose(x, x_rec, atol=1e-6)
assert jnp.allclose(lp_x, lp_x_rec, atol=1e-6)
```
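A continuous normalizing flow integrates dx/dt = f(t, x) together with the instantaneous change-of-variables formula d log p/dt = -div f(t, x). One straightforward way to obtain the divergence is the trace of the full Jacobian, which costs d forward-mode passes in d dimensions. This is presumably why the quickstart warns that automatic divergence is expensive in high dimensions.

The following is a plain-JAX sketch with fixed-step Euler integration. It is not how `ContFlowDiffrax` or `AutoJacVF` are implemented; all names are hypothetical.

```python
import jax
import jax.numpy as jnp


def vector_field(t, x):
    # same linear field as in the quickstart: dx/dt = -x
    return -x


def divergence(f, t, x):
    """Exact divergence via the trace of the full Jacobian (d forward-mode passes)."""
    jac = jax.jacfwd(lambda z: f(t, z))(x)
    return jnp.trace(jac)


def cnf_forward(x, log_p, t0=0.0, t1=1.0, n_steps=1000):
    """Fixed-step Euler integration of one unbatched sample and its log density."""
    h = (t1 - t0) / n_steps

    def step(carry, i):
        x, log_p = carry
        t = t0 + i * h
        dx = vector_field(t, x)
        # instantaneous change of variables: d log p / dt = -div f
        dlogp = -divergence(vector_field, t, x)
        return (x + h * dx, log_p + h * dlogp), None

    (x1, log_p1), _ = jax.lax.scan(step, (x, log_p), jnp.arange(n_steps))
    return x1, log_p1


# vectorize over a batch of samples
batched_cnf = jax.vmap(cnf_forward)
x = jax.random.normal(jax.random.PRNGKey(0), (5, 2))
lp = jnp.zeros(5)
y, lp_y = batched_cnf(x, lp)
```

For f(t, x) = -x in d = 2 dimensions, the divergence is constant at -2. The log density therefore increases by exactly d * T = 2 over the unit time interval, while the samples contract by roughly a factor exp(-1).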
Design Principles
- Modularity: Different parts of the library should be usable on their own. Expose as many expressive building blocks as possible.
- Flexibility: Prioritize flexibility, sometimes at the cost of adding more ways to break things.
- Runtime Shape Inference: Use a `batch + space + channels` convention and automatic vectorization for flexible data shape handling.
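The runtime shape convention can be illustrated with a small helper. The helper `apply_to_events` is hypothetical and not part of bijx. A function is written for a single event (space + channels), and any number of leading batch axes is inferred at runtime from the number of event dimensions and handled by `jax.vmap`.

```python
import jax
import jax.numpy as jnp


def apply_to_events(fn, x, event_ndim):
    """Apply `fn`, written for a single event, over any number of leading batch axes.

    The batch shape is inferred at runtime as everything except the trailing
    `event_ndim` axes (space + channels).
    """
    batch_shape = x.shape[: x.ndim - event_ndim]
    event_shape = x.shape[x.ndim - event_ndim :]
    flat = x.reshape((-1,) + event_shape)  # merge all batch axes into one
    out = jax.vmap(fn)(flat)
    return out.reshape(batch_shape + out.shape[1:])  # restore the batch axes


def field_energy(phi):
    """Example per-event function on a (space..., channels) array; returns a scalar."""
    return jnp.sum(phi**2)


x = jnp.ones((3, 4, 8, 8, 1))  # batch=(3, 4), space=(8, 8), channels=(1,)
energies = apply_to_events(field_energy, x, event_ndim=3)
```

The same helper also accepts an unbatched `(8, 8, 1)` field and returns a scalar. Per-event code never needs to know the batch shape.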
Installation
This package can be installed directly from PyPI:
```bash
pip install bijx
```
Alternatively, to install locally from source (run from the repository root):
```bash
pip install -e .
```
For development, install as an editable package with all optional dependencies (drop `docs` if you do not need to build the documentation):
```bash
pip install -e ".[dev,docs]"
```
To keep the codebase tidy, please install `pre-commit` (`pip install pre-commit`) and run `pre-commit install` before committing changes.
Documentation
To compile and open a local server for the documentation, run make livehtml in the docs/ directory.
Testing
There are two types of tests: the unit tests in `tests/` and the docstring examples in the source code under `src/bijx/`.
```bash
# Run unit tests
pytest tests/
# Run doctests
pytest src/bijx/ --doctest-modules
# Run all tests
pytest tests/ src/bijx/ --doctest-modules
# Run tests in parallel (requires pytest-xdist)
pytest -n auto
```
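The `--doctest-modules` flag collects the `>>>` examples in docstrings and checks their printed output. A docstring example of this kind is shown below for a hypothetical helper `log_det_scaling` (not part of bijx). The same check can also be run directly with the standard library's `doctest` module.

```python
import doctest
import math


def log_det_scaling(scales):
    """Return the log-determinant of an elementwise scaling (hypothetical helper).

    >>> round(log_det_scaling([2.0, 0.5]), 4)
    0.0
    >>> round(log_det_scaling([2.0, 3.0]), 4)
    1.7918
    """
    return sum(math.log(s) for s in scales)


if __name__ == "__main__":
    # prints nothing when all examples pass
    doctest.testmod()
```

Rounding the printed values keeps the examples robust to platform-dependent floating-point formatting.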
Module Layout
The library is organized into core mathematical tools, machine learning components, and specific applications.
```
bijx/
├── __init__.py          # Main package exports
├── utils.py             # General utilities
├── distributions.py     # Core Distribution classes
├── samplers.py          # Sampling helpers
├── solvers.py           # ODE solver implementations
│
├── bijections/          # All bijection-related code
│   ├── base.py          # Core Bijection, Chain, etc.
│   ├── continuous.py    # Continuous flow wrappers
│   ├── conv_cnf.py      # Convolutional CNF architecture
│   ├── coupling.py      # Coupling layers
│   └── ...
│
├── nn/                  # Neural network components (NNX)
│   ├── conv.py          # Symmetric convolutions
│   ├── embeddings.py    # Time and positional embeddings
│   └── features.py      # Reusable feature-mappers
│
├── fourier.py           # Tools for Fourier-space operations
├── lie.py               # General-purpose Lie group operations
├── cg.py                # Crouch-Grossmann ODE integrators
│
└── lattice/             # Application: Lattice Field Theory
    ├── scalar.py        # Action and observables for phi^4 theory
    └── gauge.py         # Gauge field symmetry operations
```
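As an illustration of the kind of function the lattice application provides, here is a sketch of a Euclidean lattice φ⁴ action with periodic boundaries:

S(φ) = Σ_x [ Σ_μ ½ (φ(x+μ̂) − φ(x))² + ½ m² φ(x)² + λ φ(x)⁴ ]

This parametrization (bare mass squared `m2` and coupling `lam`) is one common convention and is an assumption here. `lattice/scalar.py` may use a different one, such as the hopping-parameter form, and `phi4_action` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def phi4_action(phi, m2, lam):
    """Euclidean lattice phi^4 action for one configuration (periodic boundaries).

    S = sum_x [ sum_mu 0.5 * (phi(x + mu) - phi(x))**2 + 0.5 * m2 * phi(x)**2 + lam * phi(x)**4 ]
    """
    kinetic = 0.0
    for mu in range(phi.ndim):
        # jnp.roll with shift -1 brings phi(x + mu) to site x (periodic wrap-around)
        kinetic = kinetic + 0.5 * jnp.sum((jnp.roll(phi, -1, axis=mu) - phi) ** 2)
    potential = jnp.sum(0.5 * m2 * phi**2 + lam * phi**4)
    return kinetic + potential


# batched evaluation over a leading axis, and the force dS/dphi via autodiff
phis = jax.random.normal(jax.random.PRNGKey(0), (8, 4, 4))
actions = jax.vmap(phi4_action, in_axes=(0, None, None))(phis, 1.0, 0.5)
grad_S = jax.grad(phi4_action)(phis[0], 1.0, 0.5)
```

For a constant field φ = 1 on a 4 × 4 lattice with m² = 1 and λ = 0.5, the kinetic term vanishes and S = 16 · (0.5 + 0.5) = 16.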
File details

Source distribution: bijx-1.1.1.tar.gz (931.3 kB, tag: Source)

| Algorithm | Hash digest |
|---|---|
| SHA256 | 13da88570145dea8f5c82045357534928e5da6dd90883c3826119032e9d47062 |
| MD5 | a523ad505b25dd0cb392bf855996148b |
| BLAKE2b-256 | dd0b0b90d62ce3f9b8532e09ba501165a96a77f7ac3f89eb3cda14bb6a0b99ad |

Built distribution: bijx-1.1.1-py3-none-any.whl (109.5 kB, tag: Python 3)

| Algorithm | Hash digest |
|---|---|
| SHA256 | b4b5daceb37560f519252ecf48c97f19baaa7b0b5765d2f6066307da2f4d7017 |
| MD5 | fad777e5e9f79b6e34aec60319262bf3 |
| BLAKE2b-256 | 7959c75140186371c9393abdde061d48c7f6bb1e57a0b65f13b299840a174d62 |

Provenance

Both files were uploaded via twine/6.1.0 (CPython/3.13.7) using Trusted Publishing. Their attestation bundles share the following details:
- Publisher: publish.yml on mathisgerdes/bijx
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject digests: the SHA256 digests listed above
- Sigstore transparency entries: 724943825 (bijx-1.1.1.tar.gz), 724943828 (bijx-1.1.1-py3-none-any.whl)
- Permalink: mathisgerdes/bijx@5c2d222438b87e8abd02415e50054f83bc9ae095
- Branch / Tag: refs/tags/v1.1.1
- Owner: https://github.com/mathisgerdes
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@5c2d222438b87e8abd02415e50054f83bc9ae095
- Trigger Event: release