
Normalizing flows for JAX, with a distreqx-native API.


fleqx

Normalizing flows for JAX, built directly on distreqx.

A fleqx flow is a plain distreqx.distributions.Transformed. There's no flow-specific wrapper, so log_prob, sample, optax training and so on all work exactly as they do for any other distreqx distribution. As with any distreqx distribution, batches are handled with jax.vmap rather than by passing in arrays with a leading batch axis.
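
Here is the vmap convention in practice: a distreqx-style log_prob or sample handles one point at a time, and batching is added from outside. So that it runs with JAX alone, the sketch below uses a hand-written standard-normal log_prob and sample as stand-ins for a flow. With a fleqx flow, the corresponding calls would presumably be jax.vmap(flow.log_prob)(data) and jax.vmap(flow.sample)(keys).

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Stand-ins for a single-point distribution API: like a distreqx
# distribution, they work on ONE point of shape (DIM,), not a batch.
DIM = 2


def log_prob(x):
    # Log-density of a DIM-dimensional standard normal at a single point x.
    return -0.5 * jnp.sum(x**2) - 0.5 * DIM * jnp.log(2 * jnp.pi)


def sample(key):
    # Draw a single point of shape (DIM,).
    return jr.normal(key, (DIM,))


# Batched sampling: vmap over a batch of keys rather than requesting n samples.
keys = jr.split(jr.key(0), 1000)
xs = jax.vmap(sample)(keys)  # shape (1000, 2)

# Batched density evaluation: vmap over the leading axis of the data.
log_ps = jax.vmap(log_prob)(xs)  # shape (1000,)
```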

  • fleqx.train.fit accepts an arbitrary pytree for both dist and data (e.g. a dict of independently trained distributions and their corresponding data arrays), not just a single distribution and a single (n, dim) array. Pytree leaves may be None for fields you're not using.
  • Each flow constructor also accepts template= in place of dim, giving a distribution over an arbitrary pytree of arrays (e.g. a dict of named variables) instead of a flat vector. This requires gvcallen's distreqx fork (see Installation).
  • Bijectors that also exist in gvcallen's distreqx fork are used automatically when that fork is installed in place of the PyPI release of distreqx (see Installation below).
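
jax.tree_util makes it easy to see how a pytree of distributions lines up with a pytree of data. The sketch below uses hypothetical diagonal-Gaussian (mean, log_std) tuples in place of flows and sums the independent per-entry losses into one objective. It illustrates the pytree bookkeeping only and is not fleqx.train.fit's actual implementation. In JAX, None is an empty subtree rather than a leaf, so matching None entries in both trees are skipped.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


# Hypothetical stand-in for a distribution: a diagonal Gaussian stored as a
# (mean, log_std) tuple. In fleqx each entry would be a distreqx flow instead.
def gaussian_nll(x, params):
    mean, log_std = params
    z = (x - mean) / jnp.exp(log_std)
    # Negative log-density of each row of x (shape (n, dim)), then averaged.
    per_point = 0.5 * jnp.sum(z**2 + 2 * log_std + jnp.log(2 * jnp.pi), axis=-1)
    return jnp.mean(per_point)


k1, k2 = jr.split(jr.key(0))
data = {
    "a": jr.normal(k1, (500, 2)),
    "b": 3.0 + jr.normal(k2, (200, 3)),
    "unused": None,  # None marks an entry with nothing to fit
}
dists = {
    "a": (jnp.zeros(2), jnp.zeros(2)),
    "b": (jnp.zeros(3), jnp.zeros(3)),
    "unused": None,
}

# Map over `data` first: its structure is a prefix of `dists`, so each data
# array is paired with its whole (mean, log_std) tuple. None entries are
# empty subtrees, so the loss is never called on them.
losses = jax.tree_util.tree_map(gaussian_nll, data, dists)

# One scalar objective: the terms are independent, so the gradient with
# respect to one entry's parameters only involves that entry's own data.
total = sum(jax.tree_util.tree_leaves(losses))
```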

Three flow types are implemented so far: coupling, masked autoregressive, and planar.
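
A single coupling layer shows why coupling flows are cheap to invert and to score. Half of the input passes through unchanged and conditions an elementwise transform of the other half, so the Jacobian is block-triangular. The sketch below uses an affine transformer and a one-layer conditioner for brevity. These, and the sizes DIM and SPLIT, are assumptions for illustration; fleqx's coupling_flow may use a different transformer and conditioner.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

DIM, SPLIT = 4, 2  # hypothetical sizes: the first SPLIT dims condition the rest


def conditioner(params, x1):
    # Tiny one-layer stand-in for the conditioner network: maps the
    # untouched half x1 to a log-scale and shift for the other half.
    w, b = params
    out = jnp.tanh(w @ x1 + b)  # shape (2 * (DIM - SPLIT),)
    log_scale, shift = jnp.split(out, 2)
    return log_scale, shift


def coupling_forward(params, x):
    x1, x2 = x[:SPLIT], x[SPLIT:]
    log_scale, shift = conditioner(params, x1)
    y2 = x2 * jnp.exp(log_scale) + shift
    # The Jacobian is block-triangular, so log|det J| is just sum(log_scale).
    return jnp.concatenate([x1, y2]), jnp.sum(log_scale)


def coupling_inverse(params, y):
    y1, y2 = y[:SPLIT], y[SPLIT:]
    # y1 == x1, so the same conditioner output is available when inverting.
    log_scale, shift = conditioner(params, y1)
    x2 = (y2 - shift) * jnp.exp(-log_scale)
    return jnp.concatenate([y1, x2])


kw, kx = jr.split(jr.key(0))
params = (0.5 * jr.normal(kw, (2 * (DIM - SPLIT), SPLIT)), jnp.zeros(2 * (DIM - SPLIT)))
x = jr.normal(kx, (DIM,))
y, log_det = coupling_forward(params, x)
x_back = coupling_inverse(params, y)

# Cross-check the closed-form log-determinant against a brute-force Jacobian.
jac = jax.jacfwd(lambda v: coupling_forward(params, v)[0])(x)
brute = jnp.linalg.slogdet(jac)[1]
```

Stacking several such layers with the split permuted between them lets every dimension get transformed.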

Example

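```python
import jax.random as jr

import fleqx

# Build a coupling flow over 2-dimensional vectors.
key = jr.key(0)
flow = fleqx.coupling_flow(key, dim=2)

# The flow is a plain distreqx distribution: one sample, one log-density.
sample = flow.sample(jr.key(1))
log_p = flow.log_prob(sample)

# Fit to an (n, dim) data array; returns the trained flow and its losses.
data = jr.normal(jr.key(2), (1000, 2))
flow, losses = fleqx.fit(jr.key(3), flow, data)
```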
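
The internals of fit aren't described here. As an assumption, a fit-style function runs a maximum-likelihood loop of roughly this shape, minimizing the negative mean log_prob of the data over the model's parameters. So that the sketch needs only JAX, it swaps the flow for a diagonal Gaussian and optax for plain gradient descent. fleqx.fit's actual optimizer, batching and stopping rules may well differ, and fit_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

LR = 0.1  # assumed learning rate for plain gradient descent


def log_prob(params, x):
    # Diagonal Gaussian log-density for a single point x of shape (dim,).
    mean, log_std = params
    z = (x - mean) / jnp.exp(log_std)
    return -0.5 * jnp.sum(z**2 + 2 * log_std + jnp.log(2 * jnp.pi))


def loss_fn(params, data):
    # Negative mean log-likelihood; vmap handles the (n, dim) batch.
    return -jnp.mean(jax.vmap(log_prob, in_axes=(None, 0))(params, data))


@jax.jit
def step(params, data):
    loss, grads = jax.value_and_grad(loss_fn)(params, data)
    # Gradient descent applied to every leaf of the parameter pytree.
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss


def fit_sketch(params, data, steps=500):
    losses = []
    for _ in range(steps):
        params, loss = step(params, data)
        losses.append(loss)
    return params, jnp.stack(losses)


data = 2.0 + 0.5 * jr.normal(jr.key(2), (1000, 2))
params = (jnp.zeros(2), jnp.zeros(2))
params, losses = fit_sketch(params, data)
```

The returned (params, losses) pair mirrors the flow, losses = fleqx.fit(...) call in the example.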

coupling_flow, masked_autoregressive_flow, planar_flow and fit are also available at their fully-qualified paths (fleqx.flows.coupling_flow, fleqx.train.fit, etc.), alongside the rest of each submodule.

Installation

```bash
pip install fleqx
```

fleqx depends only on the PyPI release of distreqx, so this is enough on its own for everything except template= support, which needs gvcallen's distreqx fork. We recommend installing that fork in place of the PyPI release anyway: it includes several additional bijectors (pending upstream review as PRs) that fleqx automatically prefers over its own bundled fallbacks, with no fleqx-side configuration needed.

```bash
pip install git+https://github.com/gvcallen/distreqx.git@main
```

See the documentation for the full API.

Acknowledgements

Built on distreqx (Owen Lockwood) and parax (freezing/unwrapping of fixed sub-components, e.g. the standardizing layer created via data=). The bijectors were ported from flowjax (Daniel Ward), a more complete flows library that's worth using directly if you don't specifically need the distreqx API.

Development

```bash
pip install -e ".[test]"
pytest
```

This library was written by Claude, porting the bijectors directly from flowjax. Behaviour should be nearly identical to flowjax's, though minor differences may remain, and the code hasn't yet had a full human review.



Download files


Source Distribution

fleqx-0.0.3.tar.gz (30.0 kB)

Uploaded Source

Built Distribution


fleqx-0.0.3-py3-none-any.whl (28.1 kB)

Uploaded Python 3

File details

Details for the file fleqx-0.0.3.tar.gz.

File metadata

  • Download URL: fleqx-0.0.3.tar.gz
  • Upload date:
  • Size: 30.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for fleqx-0.0.3.tar.gz
Algorithm Hash digest
SHA256 47319a474eafd66061b5f24a462c33e5c5175f948fd772e7de0c9b2a977c806d
MD5 fa9801cf2a91f2e53375928b0b20e128
BLAKE2b-256 7b96c7ec3522e55de4334d9f6a66b8d44765fdd63d0d507033ea9b520daab48a
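
To check a downloaded sdist against the SHA256 digest listed above, you can hash it in chunks with Python's standard hashlib module and compare. The sketch assumes the file was saved to the current directory under its original name.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed for fleqx-0.0.3.tar.gz.
EXPECTED_SHA256 = "47319a474eafd66061b5f24a462c33e5c5175f948fd772e7de0c9b2a977c806d"


def sha256_of(path, chunk_size=1 << 16):
    # Stream the file so large downloads needn't fit in memory at once.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path, expected=EXPECTED_SHA256):
    # hexdigest() is lowercase, so normalise the expected value too.
    return sha256_of(path) == expected.lower()


# Usage, assuming the sdist was downloaded to the current directory:
archive = Path("fleqx-0.0.3.tar.gz")
if archive.exists():
    print("OK" if matches(archive) else "MISMATCH")
```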


Provenance

The following attestation bundles were made for fleqx-0.0.3.tar.gz:

Publisher: publish.yml on gvcallen/fleqx

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file fleqx-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: fleqx-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 28.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for fleqx-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 509fa2d5c9157e320adb4a7faa6c1f039c913e4f555093c86fd11faaccba79f8
MD5 f24888f743c3e02466f5dbc6a5412a9d
BLAKE2b-256 4917a86396ed6eca235b8f0e6fdcb6370573375d4658b876e13c6860cf0051a8


Provenance

The following attestation bundles were made for fleqx-0.0.3-py3-none-any.whl:

Publisher: publish.yml on gvcallen/fleqx

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
