
Reverse-mode-friendly bounded while loop for JAX.

Project description

jax-bounded-while

Bounded while loop in JAX.


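This is a micro-package containing a single function, bounded_while_loop: a while loop with a fixed maximum number of iterations (max_steps). Because it is implemented via lax.scan rather than lax.while_loop, it supports reverse-mode differentiation.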
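To illustrate how a while loop can be expressed with lax.scan, here is a minimal sketch under stated assumptions. It is NOT the package's source code. The name bounded_while_loop_sketch is hypothetical, and so is the behaviour it assumes when max_steps is reached (it simply returns the state after max_steps iterations). The idea: scan runs exactly max_steps times, and each step keeps the new state only while the condition still holds.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_while_loop_sketch(cond_fun, body_fun, init_val, max_steps):
    """Hypothetical illustration: loop while cond_fun holds, at most max_steps times."""

    def step(carry, _):
        # Evaluate the loop condition on the current state.
        keep_going = cond_fun(carry)
        # scan has a fixed trip count, so the body is always evaluated ...
        new_carry = body_fun(carry)
        # ... but its result is only kept (leaf by leaf) while the condition holds.
        carry = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_carry, carry
        )
        return carry, None

    final, _ = lax.scan(step, init_val, xs=None, length=max_steps)
    return final


def doubled_until_large(x0):
    # Keep doubling x while x < 10, for at most 10 steps.
    return bounded_while_loop_sketch(lambda x: x < 10.0, lambda x: x * 2.0, x0, max_steps=10)


value = doubled_until_large(1.5)             # 1.5 -> 3 -> 6 -> 12
grad = jax.grad(doubled_until_large)(1.5)    # reverse mode works through scan
```

A trade-off of this design is that the loop always performs max_steps iterations, so its cost scales with the bound rather than with the number of iterations actually needed. Also, because the body still runs after the condition has failed, non-finite values in the discarded branch of jnp.where can still contaminate gradients.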

Installation

pip install jax-bounded-while

Examples

Simple loop over a scalar:

import jax.numpy as jnp
from jax_bounded_while import bounded_while_loop


def cond_fn(x):
    return x < 5


def body_fn(x):
    return x + 1


result = bounded_while_loop(cond_fn, body_fn, jnp.asarray(0), max_steps=10)
print(result)  # 5
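For comparison, the example below shows why a bounded loop is useful. It uses only standard JAX APIs and a hypothetical function doubled_until_large. The same kind of data-dependent loop written with jax.lax.while_loop evaluates correctly and supports forward-mode differentiation (jax.jvp), but jax.grad (reverse mode) raises an error.

```python
import jax
from jax import lax


def doubled_until_large(x0):
    # Data-dependent trip count: keep doubling while x < 10.
    return lax.while_loop(lambda x: x < 10.0, lambda x: x * 2.0, x0)


value = doubled_until_large(1.5)  # 1.5 -> 3 -> 6 -> 12

# Forward-mode differentiation through lax.while_loop is supported.
_, tangent = jax.jvp(doubled_until_large, (1.5,), (1.0,))

# Reverse-mode differentiation through lax.while_loop is not.
try:
    jax.grad(doubled_until_large)(1.5)
    grad_failed = False
except Exception as err:  # JAX raises an error describing this limitation
    grad_failed = True
    print(type(err).__name__, err)
```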

PyTree carry (tuple):

import jax.numpy as jnp
from jax_bounded_while import bounded_while_loop


def cond_fn(state):
    x, _ = state
    return x < 3


def body_fn(state):
    x, y = state
    return x + 1, y * 2


result = bounded_while_loop(
    cond_fn, body_fn, (jnp.asarray(0), jnp.asarray(1)), max_steps=5
)
x, y = result
print(x, y)  # 3 8
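A scan-based loop cannot exit early. One common way to support an arbitrary PyTree carry, such as the tuple above, is to compute the next state on every step and then select between it and the current state leaf by leaf. The helper select_tree below is a hypothetical illustration of that technique. It is not part of jax-bounded-while, and the package may handle carries differently.

```python
import jax
import jax.numpy as jnp


def select_tree(pred, on_true, on_false):
    """Leafwise jnp.where over two PyTrees with the same structure."""
    return jax.tree_util.tree_map(lambda t, f: jnp.where(pred, t, f), on_true, on_false)


old = (jnp.asarray(2), jnp.asarray(4))
new = (jnp.asarray(3), jnp.asarray(8))

kept = select_tree(jnp.asarray(False), new, old)    # condition failed: keep old state
stepped = select_tree(jnp.asarray(True), new, old)  # condition held: take new state

# Any PyTree works, for example a dict carry.
state = select_tree(jnp.asarray(True), {"x": jnp.asarray(1.0)}, {"x": jnp.asarray(0.0)})
```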


Download files


Source Distribution

jax_bounded_while-0.1.tar.gz (149.2 kB)

Uploaded Source

Built Distribution


jax_bounded_while-0.1-py3-none-any.whl (6.7 kB)

Uploaded Python 3

File details

Details for the file jax_bounded_while-0.1.tar.gz.

File metadata

  • Download URL: jax_bounded_while-0.1.tar.gz
  • Upload date:
  • Size: 149.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jax_bounded_while-0.1.tar.gz
Algorithm Hash digest
SHA256 efb69ed3f90f7f65f751947861a4ecbbaac71cef623b738c15f86a5309a4788a
MD5 83654a5a86b17a5f4556921243e3ab47
BLAKE2b-256 09f740203b0a258d1d4398fff25e50fab6fbbed81be45b7b74267b118b13d82f


Provenance

The following attestation bundles were made for jax_bounded_while-0.1.tar.gz:

Publisher: cd.yml on GalacticDynamics/jax-bounded-while

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jax_bounded_while-0.1-py3-none-any.whl.

File metadata

File hashes

Hashes for jax_bounded_while-0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 752c33939eee34be808c699b140e43db8b3a5f22f72841a55fb2927a6aa876cb
MD5 fd0e2b0f6f334206b98955e78996d13a
BLAKE2b-256 01f969effe844743fdeea433c5fe140d9cc9fd7003573ef406c090c34b4f9dde


Provenance

The following attestation bundles were made for jax_bounded_while-0.1-py3-none-any.whl:

Publisher: cd.yml on GalacticDynamics/jax-bounded-while

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
