
Normalizing flows for JAX, with a distreqx-native API.

Project description

fleqx

Normalizing flows for JAX, built directly on distreqx.

A fleqx flow is a plain distreqx.distributions.Transformed. There is no flow-specific wrapper, so log_prob, sample, training with optax, and so on all work exactly as they would for any other distreqx distribution. Likewise, batches are handled with jax.vmap rather than by passing arrays with a leading batch axis.
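The vmap convention can be illustrated without fleqx itself. The sketch below assumes only JAX; `standard_normal_log_prob` is a hypothetical stand-in for `flow.log_prob`, written (like a distreqx log_prob) for a single unbatched event.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def standard_normal_log_prob(x):
    """Log-density of ONE event of shape (dim,) under an isotropic standard normal.

    Hypothetical stand-in for `flow.log_prob`: it expects shape (dim,),
    not (batch, dim).
    """
    dim = x.shape[0]
    return -0.5 * (jnp.sum(x**2) + dim * jnp.log(2 * jnp.pi))


data = jr.normal(jr.key(2), (1000, 2))  # 1000 events of dimension 2

# Batch over the leading axis with jax.vmap instead of passing the whole array in.
log_ps = jax.vmap(standard_normal_log_prob)(data)  # shape (1000,)

# Batched sampling follows the same pattern: split one key per sample, then vmap.
keys = jr.split(jr.key(1), 5)
samples = jax.vmap(lambda k: jr.normal(k, (2,)))(keys)  # shape (5, 2)
```

With a fleqx flow the same pattern would presumably read `jax.vmap(flow.log_prob)(data)` and `jax.vmap(flow.sample)(keys)`.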

Three flow types are implemented so far: coupling, masked autoregressive, and planar.
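For readers new to coupling flows, a single affine coupling layer can be sketched in pure JAX. This is a generic illustration of the technique, not fleqx's (or flowjax's) implementation: the conditioner here is a hypothetical one-layer linear map, and fleqx's coupling bijector may use a different transformer and conditioner.

```python
import jax.numpy as jnp
import jax.random as jr


def init_params(key, dim):
    """Hypothetical conditioner params: a linear map from x_a to (log_scale, shift)."""
    d = dim // 2
    w = 0.1 * jr.normal(key, (d, 2 * (dim - d)))
    b = jnp.zeros(2 * (dim - d))
    return w, b


def conditioner(params, x_a):
    w, b = params
    log_scale, shift = jnp.split(x_a @ w + b, 2)
    return jnp.tanh(log_scale), shift  # tanh keeps each scale in (1/e, e)


def coupling_forward(params, x):
    """y_a = x_a, y_b = x_b * exp(s(x_a)) + t(x_a). Returns (y, log|det J|)."""
    d = x.shape[0] // 2
    x_a, x_b = x[:d], x[d:]
    log_scale, shift = conditioner(params, x_a)
    y_b = x_b * jnp.exp(log_scale) + shift
    return jnp.concatenate([x_a, y_b]), jnp.sum(log_scale)


def coupling_inverse(params, y):
    d = y.shape[0] // 2
    y_a, y_b = y[:d], y[d:]
    # y_a equals x_a, so the conditioner output can be recomputed exactly.
    log_scale, shift = conditioner(params, y_a)
    x_b = (y_b - shift) * jnp.exp(-log_scale)
    return jnp.concatenate([y_a, x_b])
```

Because the first half passes through unchanged, the inverse is exact and the Jacobian is triangular, so its log-determinant is just the sum of the log-scales.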

Example

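import jax.numpy as jnp
import jax.random as jr

import fleqx

# Build a 2-D coupling flow: an ordinary distreqx Transformed distribution.
key = jr.key(0)
flow = fleqx.flows.coupling_flow(key, dim=2)

# Draw one sample and evaluate its log-density (a single, unbatched event).
sample = flow.sample(jr.key(1))
log_p = flow.log_prob(sample)

# Fit the flow to a (1000, 2) dataset; fit returns the trained flow and the losses.
data = jr.normal(jr.key(2), (1000, 2))
flow, losses = fleqx.train.fit(jr.key(3), flow, data)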

Installation

pip install fleqx

See the documentation for the full API.

Acknowledgements

Built on distreqx (Owen Lockwood). The bijectors were ported from flowjax (Daniel Ward), a more complete flows library that's worth using directly if you don't specifically need the distreqx API.

Development

pip install -e ".[test]"
pytest

This library was written by Claude, porting the bijectors directly from flowjax. Behaviour should be nearly identical to flowjax's, though minor differences may remain, and the code hasn't yet had a full human review.



Download files


Source Distribution

fleqx-0.0.2.tar.gz (22.8 kB)

Uploaded Source

Built Distribution


fleqx-0.0.2-py3-none-any.whl (22.5 kB)

Uploaded Python 3

File details

Details for the file fleqx-0.0.2.tar.gz.

File metadata

  • Download URL: fleqx-0.0.2.tar.gz
  • Upload date:
  • Size: 22.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for fleqx-0.0.2.tar.gz
Algorithm    Hash digest
SHA256       bd5688a013861fa6fa66105888605e0ed8f72a5704c0acb33f97391847284222
MD5          ba28eb9cfe11bff949bbd42c39ac2006
BLAKE2b-256  3a3af6439c38bc777178b7bb75d36ec3108838cb56f90227a7b49242712ec8d4
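To check a downloaded file against the published digests, the SHA256 can be recomputed locally with Python's standard library. This is a minimal sketch: the path is wherever you saved the download, and the expected digest is copied from the table above. (pip can also enforce digests through its hash-checking mode, e.g. `--hash=sha256:...` in a requirements file, though that mode then requires hashes for every dependency too.)

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected_hex):
    """True if the file's SHA256 matches the published digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.strip().lower()


EXPECTED_SDIST_SHA256 = "bd5688a013861fa6fa66105888605e0ed8f72a5704c0acb33f97391847284222"
# Example: verify("fleqx-0.0.2.tar.gz", EXPECTED_SDIST_SHA256)
```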


Provenance

The following attestation bundles were made for fleqx-0.0.2.tar.gz:

Publisher: publish.yml on gvcallen/fleqx


File details

Details for the file fleqx-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: fleqx-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 22.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for fleqx-0.0.2-py3-none-any.whl
Algorithm    Hash digest
SHA256       703b6cdce4386d42f55ec1952595414bba809b25d83f80a8af9d40c1801027aa
MD5          329dc1102c716f0d61dcbdee8abc8190
BLAKE2b-256  713be77b59dc6b3e4a20e6ef45a95d5cb42e84bc5c6e9ab891208cfb99c51ad9


Provenance

The following attestation bundles were made for fleqx-0.0.2-py3-none-any.whl:

Publisher: publish.yml on gvcallen/fleqx

