
Tqdm progress bar for JAX scans and loops

Project description

JAX-Tqdm

Add a tqdm progress bar to your JAX scans and loops.


Installation

Install with pip:

pip install jax-tqdm

Example Usage

In jax.lax.scan

from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n)
def step(carry, x):
    return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))

Here the range passed to the scan must start at 0. A tuple can be used to pass additional data to the scan, as long as its first entry is the range, e.g.:

from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10
scan_data = jnp.zeros((n, 200, 100))

@scan_tqdm(n)
def step(carry, x):
    _, d = x
    return carry + 1, d

last_number, output_data = lax.scan(step, 0, (jnp.arange(n), scan_data))
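It can help to recall what each step receives. The JAX documentation describes lax.scan as equivalent to a Python loop that slices every entry of xs along its leading axis, so with a tuple input each step gets (i, scan_data[i]). The pure-Python model below is a simplified sketch of those semantics (a tuple of lists instead of arrays, and outputs returned as a list rather than a stacked array); scan_reference is a hypothetical name, not a JAX or jax-tqdm function.

```python
def scan_reference(f, init, xs):
    """Hypothetical pure-Python model of lax.scan over a tuple of sequences.

    Simplifications: xs is a tuple of equal-length lists (not a pytree of
    arrays), and per-step outputs are collected in a list, not stacked.
    """
    length = len(xs[0])
    carry = init
    ys = []
    for t in range(length):
        # Each step receives the t-th slice of every entry: (index, data_t)
        x = tuple(leaf[t] for leaf in xs)
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def step(carry, x):
    i, d = x  # i is the iteration index that comes first in the tuple
    return carry + 1, d


n = 3
last, out = scan_reference(step, 0, (list(range(n)), ["a", "b", "c"]))
```

Because the range is the first entry, the index arrives as x[0] on every step while the remaining entries carry the per-step data.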

In jax.lax.fori_loop

from jax_tqdm import loop_tqdm
from jax import lax

n = 10_000

@loop_tqdm(n)
def step(i, val):
    return val + 1

last_number = lax.fori_loop(0, n, step, 0)

Here the loop's lower bound (the first argument to fori_loop) should be 0.
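The JAX documentation describes lax.fori_loop as equivalent to the Python loop below, so the body's first argument i runs from lower to upper - 1. Only when lower is 0 does i + 1 equal the number of completed steps, which is our reading of why the loop should start at 0 for the progress count. fori_loop_reference is a hypothetical name used for illustration.

```python
def fori_loop_reference(lower, upper, body_fun, init_val):
    """Hypothetical pure-Python model of lax.fori_loop's documented semantics."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


seen = []


def step(i, val):
    seen.append(i)  # record the loop index the body receives
    return val + 1


n = 5
result = fori_loop_reference(0, n, step, 0)
# With lower=0 the body sees i = 0, 1, ..., n - 1, so i + 1 counts completed steps.
```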

Scans & Loops Inside Vmap

For scans and loops inside a vmap, jax-tqdm can print stacked progress bars showing the individual progress of each process. To do this, wrap the initial value of the loop or scan inside a PBar class, along with the index of the progress bar. For example:

from jax_tqdm import PBar, scan_tqdm
import jax

n = 10_000

@scan_tqdm(n)
def step(carry, _):
    return carry + 1, carry + 1

def map_func(i):
    # Wrap the initial value and pass the
    # progress bar index
    init = PBar(id=i, carry=0)
    final_value, _all_numbers = jax.lax.scan(
        step, init, jax.numpy.arange(n)
    )
    return (
        final_value.carry,
        _all_numbers,
    )

last_numbers, all_numbers = jax.vmap(map_func)(jax.numpy.arange(10))

The indices of the progress bars should be contiguous integers starting from 0.

Print Rate

By default, the progress bar is updated 20 times over the course of the scan/loop (for performance purposes, see below). This update rate can be manually controlled with the print_rate keyword argument. For example:

from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n, print_rate=2)
def step(carry, x):
    return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))

will update every other step.
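As a rough model of the arithmetic, assume (this is not taken from the jax-tqdm source) that the default rate is n // 20 iterations per update and that an update fires whenever the iteration index is a multiple of print_rate. A 10,000-step scan would then update every 500 steps by default, whereas print_rate=2 would trigger 5,000 updates. The hypothetical helper below lists the iterations at which an update would fire under that model:

```python
def update_iterations(n, print_rate=None):
    """Hypothetical model: 0-based iterations at which the bar is refreshed.

    Assumes the default spreads about 20 updates over the run
    (print_rate = n // 20, at least 1) and that an update fires whenever
    the iteration index is a multiple of print_rate.
    """
    if print_rate is None:
        print_rate = max(1, n // 20)
    return list(range(0, n, print_rate))


default_updates = update_iterations(10_000)                 # every 500 steps
every_other = update_iterations(10_000, print_rate=2)       # every other step
```

Under this model the default gives 20 updates, and print_rate=2 gives one update per two iterations, so a small print_rate trades speed for a smoother bar.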

Progress Bar Type

You can select the tqdm submodule manually with the tqdm_type option. The options are 'std', 'notebook', or 'auto'.

from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n, print_rate=1, tqdm_type='std') # tqdm_type='std' or 'notebook' or 'auto'
def step(carry, x):
    return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))

Progress Bar Options

Any additional keyword arguments are passed to the tqdm progress bar constructor. For example:

from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n, print_rate=1, desc='progress bar', position=0, leave=False)
def step(carry, x):
    return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))

Why JAX-Tqdm?

JAX functions are pure, so side effects such as printing progress when running scans and loops are not allowed. However, the debug module has primitives for calling Python functions on the host from JAX code. This can be used to update a Python tqdm progress bar regularly during the computation. JAX-tqdm implements this for JAX scans and loops and is used by simply adding a decorator to the body of your update function.

Note that because, by default, the tqdm progress bar is only updated 20 times during the scan or loop, the performance overhead is negligible.
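The mechanism can be sketched in a few lines. The decorator below is a hypothetical, simplified stand-in for scan_tqdm (the names with_progress, _host_update and host_log are illustrative and not part of jax-tqdm). It uses jax.lax.cond to decide on-device whether the current iteration is a multiple of print_rate, and only then calls jax.debug.callback to run a Python function on the host. Where this sketch appends to a list, a real progress bar would advance a tqdm bar instead.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Host-side record of the iterations at which the callback fired. In a real
# progress bar this is where tqdm's update() would be called instead.
host_log = []


def _host_update(i):
    # Runs as ordinary Python on the host, invoked from compiled JAX code.
    host_log.append(int(i))


def with_progress(print_rate):
    """Hypothetical decorator sketching the idea behind scan_tqdm (not its code).

    Assumes the scanned input x is the iteration index, as in the examples above.
    """
    def decorator(body):
        def wrapped(carry, x):
            # Gate the host callback on-device so it only fires every print_rate steps.
            lax.cond(
                x % print_rate == 0,
                lambda _: jax.debug.callback(_host_update, x, ordered=True),
                lambda _: None,
                None,
            )
            return body(carry, x)
        return wrapped
    return decorator


n = 100


@with_progress(print_rate=20)
def step(carry, x):
    return carry + 1, carry + 1


last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
jax.effects_barrier()  # wait until all pending host callbacks have run
```

Keeping the modulo test inside the compiled code means the host is only contacted on the iterations that refresh the bar, which is what keeps the overhead small.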

The code is explained in more detail in this blog post.

Developers

Dependencies can be installed with poetry by running

poetry install

Pre-Commit Hooks

Pre-commit hooks can be installed by running

pre-commit install

Pre-commit checks can then be run using

task lint

Tests

Tests can be run with

task test



Download files


Source Distribution

jax_tqdm-0.4.0.tar.gz (5.6 kB)

Uploaded Source

Built Distribution

jax_tqdm-0.4.0-py3-none-any.whl (7.2 kB)

Uploaded Python 3

File details

Details for the file jax_tqdm-0.4.0.tar.gz.

File metadata

  • Download URL: jax_tqdm-0.4.0.tar.gz
  • Upload date:
  • Size: 5.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.0 CPython/3.10.6 Linux/6.12.10-76061203-generic

File hashes

Hashes for jax_tqdm-0.4.0.tar.gz
Algorithm Hash digest
SHA256 d76a7ab07286ed8024fa019add2668b4d1d6ddf3bcaa166f2e1466d9e51b99a0
MD5 00f3c1204e1fc0b7d7b7a00145003b77
BLAKE2b-256 d9ddb132ae4a7e22604c2738dff805d151fd81f6a3e77fab36c8dcbd43003f9e


File details

Details for the file jax_tqdm-0.4.0-py3-none-any.whl.

File metadata

  • Download URL: jax_tqdm-0.4.0-py3-none-any.whl
  • Upload date:
  • Size: 7.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.0 CPython/3.10.6 Linux/6.12.10-76061203-generic

File hashes

Hashes for jax_tqdm-0.4.0-py3-none-any.whl
Algorithm Hash digest
SHA256 973d81946ea22d2cad30f2a0434c072b7471ecc23baea2c42ec65e0b2bc2f41f
MD5 426615e2e0cbd1733f7e7781a11302a4
BLAKE2b-256 1d2da94edc83d0012a5968abe1a16a6dc402100a76323fb30e229afe0f75bfe3

