
Tqdm progress bar for JAX scans and loops


JAX-tqdm

Add a tqdm progress bar to your JAX scans and loops.

The code is explained in this blog post.

Example usage

in jax.lax.scan

from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

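# n is the total number of scan iterations
@scan_tqdm(n)
def step(carry, x):
    # x is the current element of jnp.arange(n), i.e. the iteration index
    return carry + 1, carry + 1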

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
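
To make the expected result of the example above concrete, here is a minimal pure-Python sketch of the looping behaviour of a scan. The helper name `reference_scan` is hypothetical and is not part of JAX or jax_tqdm; it ignores tracing, compilation, and the progress bar, and only mirrors how the carry and the per-step outputs are threaded through the loop.

```python
def reference_scan(f, init, xs):
    """Hypothetical pure-Python stand-in for the loop that a scan performs."""
    carry = init
    ys = []
    for x in xs:
        # Each step receives the current carry and one element of xs,
        # and returns the new carry plus one output value.
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def step(carry, x):
    # Same body as the decorated example: increment the carry and emit it.
    return carry + 1, carry + 1


n = 10_000
last_number, all_numbers = reference_scan(step, 0, range(n))
print(last_number)      # 10000
print(all_numbers[:3])  # [1, 2, 3]
```

Under these assumptions, the decorated `lax.scan` call should produce the same values, with `all_numbers` returned as an array rather than a Python list.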

in jax.lax.fori_loop

from jax_tqdm import loop_tqdm
from jax import lax

n = 10_000

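# n is the total number of loop iterations
@loop_tqdm(n)
def step(i, val):
    # i is the loop index, val is the value carried between iterations
    return val + 1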

last_number = lax.fori_loop(0, n, step, 0)
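
As a rough illustration of what the loop above computes, the following pure-Python sketch mirrors the documented looping semantics of `fori_loop`. The helper name `reference_fori_loop` is hypothetical and is not part of JAX or jax_tqdm; it leaves out compilation and the progress bar.

```python
def reference_fori_loop(lower, upper, body_fun, init_val):
    """Hypothetical pure-Python stand-in for the loop that fori_loop performs."""
    val = init_val
    for i in range(lower, upper):
        # The body receives the loop index and the current value,
        # and returns the value for the next iteration.
        val = body_fun(i, val)
    return val


def step(i, val):
    # Same body as the decorated example: add one per iteration.
    return val + 1


n = 10_000
last_number = reference_fori_loop(0, n, step, 0)
print(last_number)  # 10000
```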
