
A bounded (and autodiff friendly) while loop in JAX.


jax-bounded-while

Bounded while loop in JAX.


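This micro-package contains a single function, bounded_while_loop: a while loop that runs for at most max_steps iterations.
It is implemented via jax.lax.scan rather than jax.lax.while_loop, which makes it compatible with reverse-mode automatic differentiation.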
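To illustrate the idea, here is a minimal sketch of how a bounded while loop can be built on jax.lax.scan. It shows the general technique only and is not the package's actual source; the name bounded_while_sketch is hypothetical, and the real implementation may differ (for example, in how it skips finished iterations).

```python
import jax.numpy as jnp
from jax import lax


def bounded_while_sketch(cond_fn, body_fn, init, max_steps):
    """Hypothetical stand-in: a while loop capped at max_steps iterations."""

    def step(carry, _):
        # Apply the body only while the condition holds;
        # once it fails, pass the carry through unchanged.
        new_carry = lax.cond(cond_fn(carry), body_fn, lambda c: c, carry)
        return new_carry, None

    # scan always runs exactly max_steps iterations, so the trip count is static.
    final, _ = lax.scan(step, init, xs=None, length=max_steps)
    return final


start = jnp.array(0, dtype=jnp.int32)
finished = bounded_while_sketch(lambda x: x < 5, lambda x: x + 1, start, max_steps=10)
truncated = bounded_while_sketch(lambda x: x < 5, lambda x: x + 1, start, max_steps=3)
print(finished, truncated)  # 5 3
```

In this sketch, a max_steps smaller than the number of iterations the condition would allow simply stops the loop early, as the second call shows.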

Installation

pip install jax-bounded-while

Examples

Simple loop over a scalar:

import jax.numpy as jnp
from jax_bounded_while import bounded_while_loop


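def cond_fn(x):
    # Keep looping while the carry is below 5.
    return x < 5


def body_fn(x):
    # One loop step: increment the carry.
    return x + 1


# max_steps=10 caps the iteration count; here the condition fails first, after 5 steps.
result = bounded_while_loop(cond_fn, body_fn, jnp.asarray(0), max_steps=10)
print(result)  # 5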

PyTree carry (tuple):

import jax.numpy as jnp
from jax_bounded_while import bounded_while_loop


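def cond_fn(state):
    # Only the first element of the carry controls termination.
    x, _ = state
    return x < 3


def body_fn(state):
    # Increment the counter and double the accumulator.
    x, y = state
    return x + 1, y * 2


# The carry is a tuple PyTree; three steps run before x reaches 3.
result = bounded_while_loop(
    cond_fn, body_fn, (jnp.asarray(0), jnp.asarray(1)), max_steps=5
)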
print(result)  # (Array(3, dtype=int32), Array(8, dtype=int32))
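Differentiating through the loop:

The reverse-mode claim can be checked with jax.grad. The sketch below uses a hypothetical scan-based stand-in (scan_while, not part of this package) so that it runs without the package installed; with the package available, bounded_while_loop would presumably be called in its place with the same arguments.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_while(cond_fn, body_fn, init, max_steps):
    """Hypothetical stand-in for a scan-based bounded while loop."""

    def step(carry, _):
        # Run the body while the condition holds, otherwise keep the carry.
        return lax.cond(cond_fn(carry), body_fn, lambda c: c, carry), None

    return lax.scan(step, init, xs=None, length=max_steps)[0]


def final_value(x0):
    # Double the value until it reaches at least 10.
    init = jnp.asarray(x0, dtype=jnp.float32)
    return scan_while(lambda x: x < 10.0, lambda x: x * 2.0, init, max_steps=10)


value = final_value(1.0)           # 1 -> 2 -> 4 -> 8 -> 16
grad = jax.grad(final_value)(1.0)  # four doublings, so d(final)/d(x0) = 16
print(value, grad)  # 16.0 16.0
```

The same function written with jax.lax.while_loop would not support jax.grad, because reverse-mode differentiation needs a statically bounded number of iterations to store intermediate values.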
