A bounded (and autodiff-friendly) while loop in JAX.

Project description

jax-bounded-while

Bounded while loop in JAX.

This is a micro-package containing a single function, bounded_while_loop: a reverse-mode-friendly, bounded while loop implemented via lax.scan.
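The idea behind "implemented via lax.scan" can be illustrated with a minimal sketch. This is not the package's actual source: the name `scan_bounded_while` is hypothetical, and the sketch assumes `max_steps` is a static Python int and that `body_fn` preserves the structure and dtypes of the carry. Each scan step evaluates the condition on the current carry and uses `jnp.where` to keep the old value once the condition is false.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_bounded_while(cond_fn, body_fn, init_val, max_steps):
    """Hypothetical sketch: apply body_fn while cond_fn holds, at most max_steps times.

    A fixed-length lax.scan is used instead of lax.while_loop so that
    reverse-mode autodiff (jax.grad) can differentiate through the loop.
    """

    def step(carry, _):
        # Is this iteration "live"? Evaluated on the carry before updating it.
        keep_going = cond_fn(carry)
        new_carry = body_fn(carry)
        # Select leaf by leaf over the pytree: take the update only while the
        # condition holds. Once it fails, the carry stops changing, so the
        # condition stays false for all remaining steps.
        carry = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_carry, carry
        )
        return carry, None

    final, _ = lax.scan(step, init_val, xs=None, length=max_steps)
    return final


# Usage: the scalar example from this README, with an explicit int32 dtype.
print(int(scan_bounded_while(lambda x: x < 5, lambda x: x + 1,
                             jnp.asarray(0, dtype=jnp.int32), max_steps=10)))  # 5

# Reverse mode: 1.5 is doubled three times (to 12.0), so d(out)/d(x0) = 2**3.
grad_fn = jax.grad(
    lambda x0: scan_bounded_while(lambda x: x < 10.0, lambda x: 2.0 * x, x0, max_steps=10)
)
print(float(grad_fn(jnp.float32(1.5))))  # 8.0
```

The trade-off of this design is that all `max_steps` iterations always run, and `body_fn` is still evaluated after the condition has become false. It should therefore be safe to call on the final carry; if it can produce inf or NaN there, the gradient through `jnp.where` can become NaN as well.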

Note: This library is being renamed to jaxmore and expanded in scope. In addition to bounded_while_loop, it will include more JAX-related functionality, such as a vmap that supports keyword arguments. This will be the last release for jax-bounded-while.

Installation

pip install jax-bounded-while

Examples

Simple loop over a scalar:

import jax.numpy as jnp
from jax_bounded_while import bounded_while_loop


def cond_fn(x):
    return x < 5


def body_fn(x):
    return x + 1


result = bounded_while_loop(cond_fn, body_fn, jnp.asarray(0), max_steps=10)
print(result)  # 5
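For contrast, the sketch below runs a similar loop with plain `jax.lax.while_loop`, whose trip count depends on the data. It is an illustration rather than part of this package, and the function name `double_until_ten` is hypothetical. Forward-mode differentiation (`jax.jvp`) works through the loop, but reverse mode (`jax.grad`) raises a `ValueError`. That is the limitation a bounded, scan-based loop works around.

```python
import jax
import jax.numpy as jnp
from jax import lax


def double_until_ten(x0):
    # Unbounded loop: keep doubling while the value is below 10.
    return lax.while_loop(lambda x: x < 10.0, lambda x: 2.0 * x, x0)


x0 = jnp.float32(1.5)
print(float(double_until_ten(x0)))  # 12.0

# Forward mode works: three doublings give a tangent of 2**3 = 8.
_, tangent = jax.jvp(double_until_ten, (x0,), (jnp.float32(1.0),))
print(float(tangent))  # 8.0

# Reverse mode does not: JAX cannot transpose lax.while_loop.
reverse_mode_failed = False
try:
    jax.grad(double_until_ten)(x0)
except ValueError:
    reverse_mode_failed = True
print(reverse_mode_failed)  # True
```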

PyTree carry (tuple):

import jax.numpy as jnp
from jax_bounded_while import bounded_while_loop


def cond_fn(state):
    x, _ = state
    return x < 3


def body_fn(state):
    x, y = state
    return x + 1, y * 2


result = bounded_while_loop(
    cond_fn, body_fn, (jnp.asarray(0), jnp.asarray(1)), max_steps=5
)
print(result)  # (Array(3, dtype=int32), Array(8, dtype=int32))

Download files

Source Distribution

jax_bounded_while-0.1.2.tar.gz (149.3 kB)

Uploaded Source

Built Distribution

jax_bounded_while-0.1.2-py3-none-any.whl (6.9 kB)

Uploaded Python 3

File details

Details for the file jax_bounded_while-0.1.2.tar.gz.

File metadata

  • Download URL: jax_bounded_while-0.1.2.tar.gz
  • Upload date:
  • Size: 149.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jax_bounded_while-0.1.2.tar.gz
Algorithm    Hash digest
SHA256       4e09c733cdfc6d904d166dc4014bcd768373aa6680d7d568906721df35dde767
MD5          89ec3df7c713a29ab00dfbb825ef7f1d
BLAKE2b-256  83cd9b6b9b68d0e134774fd9e6127b482586e96d057d728d5a64668a27379b66
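The SHA256 digest listed for the source distribution can be checked locally before installing. Here is a minimal sketch using Python's standard-library `hashlib`. It assumes the sdist has already been downloaded, and the helper names `sha256_of` and `matches_expected` are hypothetical.

```python
import hashlib

# SHA256 digest listed on this page for jax_bounded_while-0.1.2.tar.gz.
EXPECTED_SHA256 = "4e09c733cdfc6d904d166dc4014bcd768373aa6680d7d568906721df35dde767"


def sha256_of(path):
    """Return the hex SHA256 digest of the file at `path`."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # Read 64 KiB at a time so large files need not fit in memory.
        for chunk in iter(lambda: fh.read(65536), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_expected(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 digest equals the expected hex string."""
    return sha256_of(path) == expected.lower()


# Usage (assumes the sdist was downloaded to the current directory):
# matches_expected("jax_bounded_while-0.1.2.tar.gz")
```

pip can perform the same check at install time in its hash-checking mode, where each pinned requirement carries a `--hash=sha256:...` option in a requirements file.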

Provenance

The following attestation bundles were made for jax_bounded_while-0.1.2.tar.gz:

Publisher: cd.yml on GalacticDynamics/jax-bounded-while

File details

Details for the file jax_bounded_while-0.1.2-py3-none-any.whl.

File hashes

Hashes for jax_bounded_while-0.1.2-py3-none-any.whl
Algorithm    Hash digest
SHA256       c911bc9c18622ebab7575c5b946270a8b709de99151730ff9d0d93a235a9ec82
MD5          87a555b9abf71ae534385d46b2757f4e
BLAKE2b-256  49f9e0aff648bd870bf252e4b913bf0aeaa98a0bdee1622309b6066927b4c4bd

Provenance

The following attestation bundles were made for jax_bounded_while-0.1.2-py3-none-any.whl:

Publisher: cd.yml on GalacticDynamics/jax-bounded-while
