Tqdm progress bar for JAX scans and loops
JAX-tqdm
Add a tqdm progress bar to your JAX scans and loops.
The code is explained in this blog post.
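The blog post explains the library's own implementation. The sketch below is only an illustration of one way the general idea can work, and it is not the jax_tqdm source. Code inside a compiled `lax.scan` cannot call `tqdm` directly. So each step uses `jax.debug.callback` to run a small Python function on the host, and that function advances an ordinary `tqdm` bar. The names `bar`, `_advance` and `run` are hypothetical. The sketch assumes a JAX version that provides `jax.debug.callback` and `jax.effects_barrier`.

```python
import jax
import jax.numpy as jnp
from jax import lax
from tqdm import tqdm

n = 1_000
bar = tqdm(total=n)  # host-side progress bar, created outside the traced code


def _advance(_i):
    # Runs in Python on the host, not inside XLA: move the bar forward one step.
    bar.update(1)


def step(carry, i):
    # Ask JAX to call back into Python on every iteration of the scan.
    jax.debug.callback(_advance, i)
    return carry + 1, carry + 1


@jax.jit
def run(xs):
    return lax.scan(step, 0, xs)


last, all_numbers = run(jnp.arange(n))
jax.effects_barrier()  # wait until every pending host callback has run
bar.close()
```

Calling back on every step is the simplest version. For long loops, the per-call overhead can be noticeable, so a practical decorator usually reports progress only every few iterations.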
Example usage
In jax.lax.scan
```python
from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n)
def step(carry, x):
    return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
```
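The example scans over `jnp.arange(n)`, so each `x` is the iteration index. The sketch below shows how a `scan_tqdm`-style decorator could use that index. It is a hypothetical illustration, not jax_tqdm's code. `scan_progress` and `print_rate` are made-up names, and the sketch assumes that `xs` are the indices `0..n-1`. To limit host round-trips, it calls back only every `rate` steps and on the final step, using `lax.cond`.

```python
import jax
import jax.numpy as jnp
from jax import lax
from tqdm import tqdm


def scan_progress(n, print_rate=None):
    """Hypothetical scan_tqdm-style decorator; an illustration, not jax_tqdm's code.

    Assumes the xs passed to lax.scan are the iteration indices 0..n-1,
    so the current step can be read from x.
    """
    rate = print_rate or max(1, n // 20)  # by default, about 20 host updates per run

    def decorator(func):
        bar = tqdm(total=n)

        def _advance(amount):
            # Runs in Python on the host, outside the compiled computation.
            bar.update(int(amount))

        def wrapped(carry, x):
            done = x + 1  # iterations completed after this step (x is the index)
            # Size of this update: a full chunk, or the leftover on the last step.
            amount = jnp.where(done % rate == 0, rate, done % rate)
            # Only call back to the host every `rate` steps and on the final step.
            lax.cond((done % rate == 0) | (done == n),
                     lambda: jax.debug.callback(_advance, amount),
                     lambda: None)
            return func(carry, x)

        wrapped.bar = bar  # expose the bar so the caller can inspect or close it
        return wrapped

    return decorator


n = 10_000

@scan_progress(n)
def step(carry, x):
    return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
jax.effects_barrier()  # wait for any pending host callbacks
step.bar.close()
```

Each callback adds an increment rather than setting an absolute position. This means the final count is correct even if JAX does not run the unordered callbacks in strict iteration order.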
In jax.lax.fori_loop
```python
from jax_tqdm import loop_tqdm
from jax import lax

n = 10_000

@loop_tqdm(n)
def step(i, val):
    return val + 1

last_number = lax.fori_loop(0, n, step, 0)
```
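With `lax.fori_loop`, the body already receives the loop counter `i`, so no index array is needed. The following is a hypothetical sketch of a `loop_tqdm`-style decorator built on that fact, using the same callback idea. It is not jax_tqdm's implementation. `loop_progress` and `print_rate` are illustrative names, and the sketch assumes the loop runs from `0` to `n - 1`.

```python
import jax
import jax.numpy as jnp
from jax import lax
from tqdm import tqdm


def loop_progress(n, print_rate=None):
    """Hypothetical loop_tqdm-style decorator; an illustration, not jax_tqdm's code.

    fori_loop hands the loop counter i to the body, so no index array is needed.
    Assumes the loop runs i = 0..n-1.
    """
    rate = print_rate or max(1, n // 20)

    def decorator(body):
        bar = tqdm(total=n)

        def _advance(amount):
            # Host-side update; runs in Python, not in the compiled loop.
            bar.update(int(amount))

        def wrapped(i, val):
            done = i + 1  # iterations completed after this one
            amount = jnp.where(done % rate == 0, rate, done % rate)
            # Report every `rate` iterations and once more on the final iteration.
            lax.cond((done % rate == 0) | (done == n),
                     lambda: jax.debug.callback(_advance, amount),
                     lambda: None)
            return body(i, val)

        wrapped.bar = bar  # expose the bar so the caller can inspect or close it
        return wrapped

    return decorator


n = 10_000

@loop_progress(n)
def step(i, val):
    return val + 1

last_number = lax.fori_loop(0, n, step, 0)
jax.effects_barrier()  # wait for any pending host callbacks
step.bar.close()
```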
Source Distribution
jax_tqdm-0.1.0.tar.gz (2.7 kB)