
JAX-compatible progress bars for jax.lax.scan and jax.lax.fori_loop.


tqdx

Adds tqdm progress bars to jax.lax.scan and jax.lax.fori_loop. Ordinary Python progress bars such as tqdm do not work inside these loops or inside jit-compiled functions: JAX stages the loop body out by tracing it, so Python side effects such as printing or updating a bar run during tracing rather than on every iteration. tqdx works around this by issuing callbacks from the compiled loop to update progress bars that live on the host.

import tqdx

...
carry, ys = tqdx.scan(f, init, xs)
Processing: 100%|███████████████████████████████████████████| 50/50 [02:38<00:00,  3.20s/it]
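The callback approach described above can be illustrated with plain JAX. The following is a minimal sketch of the general technique, not tqdx's actual implementation: inside a jit-compiled jax.lax.scan, a plain Python side effect runs only while JAX traces the step function, whereas jax.debug.callback runs on the host every time the compiled loop executes a step. The lists host_updates and trace_calls are hypothetical stand-ins for a tqdm bar, and the names body and run are illustrative.

```python
import jax
import jax.numpy as jnp

# Hypothetical host-side state; a real progress bar would call tqdm's update() instead.
host_updates = []  # appended by the host callback, once per executed step
trace_calls = []   # appended by plain Python, only while JAX traces `body`


def body(carry, x):
    trace_calls.append(1)  # plain side effect: happens at trace time, not per step
    jax.debug.callback(lambda: host_updates.append(1))  # staged into the compiled loop
    return carry + x, carry


@jax.jit
def run(xs):
    # Strongly typed zero carry with the same dtype as xs.
    return jax.lax.scan(body, jnp.zeros((), xs.dtype), xs)


carry, ys = run(jnp.arange(5))
jax.effects_barrier()  # wait until all pending host callbacks have finished
print(len(host_updates))  # 5
print(len(trace_calls))   # number of traces, independent of len(xs)
```

The callback count grows with the number of loop iterations, while the trace-time count stays at however many times JAX traced body (typically once), which is why a host callback is needed to drive a per-iteration progress bar.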

Features

  • Progress bars for JAX: See the progress of your computations when using jax.lax.scan and jax.lax.fori_loop.
  • Works with jax.jit: Progress bars show up even inside jit-compiled code.
  • Minimal syntax change: Just replace your calls to jax.lax.scan and jax.lax.fori_loop with tqdx.scan and tqdx.fori_loop.
  • Few dependencies: Requires only JAX and tqdm.
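Because tqdx is meant as a drop-in replacement, one way to keep it optional is to pick the loop functions at import time. This is a minimal sketch under an assumption: that tqdx.scan and tqdx.fori_loop accept the same positional arguments as jax.lax.scan and jax.lax.fori_loop, as the "replace your calls" advice implies. The helper names cumulative_sum and triangular are hypothetical.

```python
import jax
import jax.numpy as jnp

# Use tqdx when it is installed, otherwise fall back to plain jax.lax.
# Assumption: both tqdx functions mirror the jax.lax positional signatures.
try:
    import tqdx
    scan, fori_loop = tqdx.scan, tqdx.fori_loop
except ImportError:
    scan, fori_loop = jax.lax.scan, jax.lax.fori_loop


def cumulative_sum(xs):
    # Carry the running total; emit the total *before* adding each element.
    def step(total, x):
        return total + x, total
    return scan(step, jnp.zeros((), xs.dtype), xs)


def triangular(n):
    # 0 + 1 + ... + (n - 1) using fori_loop with static bounds.
    return fori_loop(0, n, lambda i, acc: acc + i, 0)


total, prefix = jax.jit(cumulative_sum)(jnp.arange(5))
print(int(total), prefix.tolist())  # 10 [0, 0, 1, 3, 6]
print(int(triangular(5)))           # 10
```

The rest of the code only ever refers to scan and fori_loop, so switching between progress bars and plain JAX loops does not touch the loop bodies.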

Usage

The following example uses tqdx.scan and tqdx.fori_loop in place of jax.lax.scan and jax.lax.fori_loop, with jax.debug.callback calls that sleep to simulate slow work. The two functions can be nested arbitrarily, and each progress bar still updates correctly.

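import jax
import jax.numpy as jnp
from time import sleep

import tqdx


def step(carry, x):
    def body_fun(i, val):
        # Simulate slow work. A bare sleep() would run only once, while JAX
        # traces the function, so it is wrapped in a host callback.
        jax.debug.callback(lambda: sleep(0.5))
        return val + i

    # Simulated work for each step of the outer scan.
    jax.debug.callback(lambda: sleep(0.5))
    # Inner loop over 10 iterations, with its own progress bar.
    carry = tqdx.fori_loop(0, 10, body_fun, carry)
    return carry, x + 1


def f(xs):
    # Outer loop over the elements of xs, with its own progress bar.
    return tqdx.scan(step, 0, xs)


xs = jnp.arange(10)
result, _ = jax.jit(f)(xs)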
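As a check on what the example above computes, the same arithmetic can be run with plain jax.lax and no progress bars. Each inner fori_loop adds 0 + 1 + ... + 9 = 45 to the carry, and the outer scan repeats this once per element of xs (10 times), so result should be 450. This sketch assumes tqdx.scan and tqdx.fori_loop return the same values as their jax.lax counterparts, which is what the drop-in replacement claim implies.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # Same arithmetic as the tqdx example, minus the sleep callbacks.
    carry = jax.lax.fori_loop(0, 10, lambda i, val: val + i, carry)
    return carry, x + 1


def f(xs):
    return jax.lax.scan(step, 0, xs)


result, ys = jax.jit(f)(jnp.arange(10))
print(int(result))  # 450
print(ys.tolist())  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```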

Installation

pip install tqdx


Download files


Source Distribution

tqdx-0.1.4.tar.gz (3.7 kB)

Built Distribution

tqdx-0.1.4-py3-none-any.whl (6.5 kB)

File details

Details for the file tqdx-0.1.4.tar.gz.

File metadata

  • Download URL: tqdx-0.1.4.tar.gz
  • Upload date:
  • Size: 3.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.8.4

File hashes

Hashes for tqdx-0.1.4.tar.gz
Algorithm     Hash digest
SHA256        5b3dbf82ddd18807e84ab2cb61489a7413337ac0c11fe9ffeb2549bac0f3cdc7
MD5           19690086fcd769fe8be27e39fb5e78c7
BLAKE2b-256   987034c94031fee59b2c13e6a65973ae52bc84c8561f49a458629e8f1eefb8d7

File details

Details for the file tqdx-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: tqdx-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 6.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.8.4

File hashes

Hashes for tqdx-0.1.4-py3-none-any.whl
Algorithm     Hash digest
SHA256        24b7ec8c6aaeaf6a2de1ec8486c37f598cd776e3159ea43383209b6f49aa1ef8
MD5           d8224fa09fd159c70795576fd2f248c5
BLAKE2b-256   823184fd187411530fb9dbdd944b5f8d6a52718a861b6f91c882add73e1f8d9e
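The hashes above let you check that a downloaded file is intact. A minimal sketch using only Python's standard hashlib module; the helper names sha256_of_stream, sha256_of, and verify are illustrative, and the usage line assumes the file was downloaded to the current directory.

```python
import hashlib
import os

# SHA256 digests copied from the hash tables above.
EXPECTED_SHA256 = {
    "tqdx-0.1.4.tar.gz": "5b3dbf82ddd18807e84ab2cb61489a7413337ac0c11fe9ffeb2549bac0f3cdc7",
    "tqdx-0.1.4-py3-none-any.whl": "24b7ec8c6aaeaf6a2de1ec8486c37f598cd776e3159ea43383209b6f49aa1ef8",
}


def sha256_of_stream(fh, chunk_size=1 << 16):
    """Hex SHA256 digest of a binary file object, read in fixed-size chunks."""
    digest = hashlib.sha256()
    for chunk in iter(lambda: fh.read(chunk_size), b""):
        digest.update(chunk)
    return digest.hexdigest()


def sha256_of(path):
    """Hex SHA256 digest of the file at `path`."""
    with open(path, "rb") as fh:
        return sha256_of_stream(fh)


def verify(path):
    """True if the file at `path` matches the published digest for its file name."""
    return sha256_of(path) == EXPECTED_SHA256[os.path.basename(path)]


# Usage, assuming the sdist is in the current directory:
# print(verify("tqdx-0.1.4.tar.gz"))
```

pip can also enforce these digests itself through its hash-checking mode (--require-hashes, with --hash=sha256:<digest> entries in a requirements file).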