
JAX compatible progress bars for scan and fori_loop.

Project description

tqdx

Adds tqdm progress bars to jax.lax.scan and jax.lax.fori_loop. Common Python progress bars such as tqdm do not work inside JAX's jit-compiled functions: the Python code in a traced function runs only once, while JAX traces it, so side effects like printing or updating a bar do not happen on each iteration of the compiled loop. tqdx works around this limitation by using host callbacks to update progress bars created on the host.
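A minimal illustration of the limitation and of the callback approach, using only JAX (no tqdx): ordinary Python code in a jitted loop body runs while JAX traces it, whereas `jax.debug.callback` is staged into the compiled program and runs on the host on every iteration. The list names below are hypothetical, chosen only to count calls.

```python
import jax
import jax.numpy as jnp

trace_time_calls = []  # filled by ordinary Python code in the loop body
runtime_calls = []     # filled by a host callback in the loop body


def body(i, val):
    # Ordinary Python side effect: executes while JAX traces `body`,
    # not once per iteration of the compiled loop.
    trace_time_calls.append(1)
    # Host callback: staged into the compiled program and executed on the
    # host every time the loop body runs.
    jax.debug.callback(lambda: runtime_calls.append(1))
    return val + i


@jax.jit
def run(init):
    return jax.lax.fori_loop(0, 5, body, init)


result = run(jnp.int32(0))
jax.effects_barrier()  # wait until all pending callbacks have finished

print(len(trace_time_calls), len(runtime_calls), int(result))
```

Only the callback count tracks the number of iterations, which is why a progress bar has to be driven from a host callback rather than from plain Python code in the loop body.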

import tqdx

...
carry, ys = tqdx.scan(f, init, xs)
Processing: 100%|███████████████████████████████████████████| 50/50 [02:38<00:00,  3.20s/it]

Features

  • Progress bars for JAX: See the progress of your computations when using jax.lax.scan and jax.lax.fori_loop.
  • Works with jax.jit: Progress bars show up even inside jit-compiled code.
  • Minimal syntax change: Just replace your calls to jax.lax.scan and jax.lax.fori_loop with tqdx.scan and tqdx.fori_loop.
  • No extra dependencies: Only requires JAX and tqdm.
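To show what a drop-in replacement for jax.lax.scan can look like, here is a minimal sketch that wraps the loop body with a host callback driving a tqdm bar. `scan_with_progress` is a hypothetical name and this is not tqdx's actual implementation; it assumes `xs` is a non-empty array (or pytree of arrays) with a static leading axis, and it uses `jax.debug.callback(..., ordered=True)` so the host sees the iterations in order.

```python
import jax
import jax.numpy as jnp
from tqdm import tqdm


def scan_with_progress(f, init, xs, desc="Processing"):
    """Hypothetical drop-in for jax.lax.scan that reports progress via tqdm.

    Assumes `xs` is a non-empty array or pytree of arrays sharing a leading axis.
    """
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]  # static trip count
    state = {}  # host-side storage for the current bar

    def tick(i):
        # Runs on the host once per iteration with the concrete index.
        i = int(i)
        if i == 0:
            state["bar"] = tqdm(total=n, desc=desc)
        state["bar"].update(1)
        if i == n - 1:
            state["bar"].close()

    def body(carry, x_and_i):
        x, i = x_and_i
        # ordered=True keeps the host callbacks in iteration order.
        jax.debug.callback(tick, i, ordered=True)
        return f(carry, x)

    # Pair each element of xs with its index so the callback knows where it is.
    return jax.lax.scan(body, init, (xs, jnp.arange(n)))


def f(carry, x):
    return carry + x, x * 2


carry, ys = jax.jit(lambda xs: scan_with_progress(f, jnp.int32(0), xs))(jnp.arange(5))
```

Creating the bar inside the callback at the first iteration, rather than at trace time, means each call of the compiled function gets a fresh bar instead of reusing one created during tracing.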

Usage

The following example shows how to use tqdx with jax.lax.scan and jax.lax.fori_loop. These functions can be nested arbitrarily, and the progress bars still display correctly.

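import jax
import jax.numpy as jnp
import tqdx
from time import sleep

def step(carry, x):
    def body_fun(i, val):
        # Simulate slow work on the host inside the inner loop.
        jax.debug.callback(lambda: sleep(0.5))
        return val + i

    # Simulate slow work on the host in each outer step.
    jax.debug.callback(lambda: sleep(0.5))
    # Inner loop: 10 iterations, with its own progress bar.
    carry = tqdx.fori_loop(0, 10, body_fun, carry)
    return carry, x + 1

def f(xs):
    # Outer loop: one iteration per element of xs.
    return tqdx.scan(step, 0, xs)


xs = jnp.arange(10)
result, _ = jax.jit(f)(xs)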
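A similar sketch for jax.lax.fori_loop, nested inside jax.lax.scan in the same way as the example above. `fori_loop_with_progress` is a hypothetical helper, not part of tqdx; it assumes `lower` and `upper` are Python ints with `lower < upper`. Because the inner bar is created when the callback for the first iteration fires and closed after the last one, a fresh inner bar appears for every outer step.

```python
import jax
import jax.numpy as jnp
from tqdm import tqdm


def fori_loop_with_progress(lower, upper, body_fun, init_val, desc="Inner"):
    """Hypothetical helper: jax.lax.fori_loop plus a host-side tqdm bar.

    Assumes `lower` and `upper` are Python ints with lower < upper.
    """
    total = upper - lower
    state = {}  # host-side storage for the current bar

    def tick(i):
        # Runs on the host with the concrete loop index.
        i = int(i)
        if i == lower:
            # leave=False clears finished inner bars when nested.
            state["bar"] = tqdm(total=total, desc=desc, leave=False)
        state["bar"].update(1)
        if i == upper - 1:
            state["bar"].close()

    def wrapped(i, val):
        # ordered=True keeps the host callbacks in iteration order.
        jax.debug.callback(tick, i, ordered=True)
        return body_fun(i, val)

    return jax.lax.fori_loop(lower, upper, wrapped, init_val)


def step(carry, x):
    # Each outer step runs a 4-iteration inner loop adding 0 + 1 + 2 + 3.
    carry = fori_loop_with_progress(0, 4, lambda i, v: v + i, carry)
    return carry, x + 1


carry, ys = jax.jit(lambda xs: jax.lax.scan(step, jnp.int32(0), xs))(jnp.arange(3))
```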

Installation

pip install tqdx



Download files


Source Distribution

tqdx-0.1.1.tar.gz (3.5 kB)

Uploaded Source

Built Distribution


tqdx-0.1.1-py3-none-any.whl (4.4 kB)

Uploaded Python 3

File details

Details for the file tqdx-0.1.1.tar.gz.

File metadata

  • Download URL: tqdx-0.1.1.tar.gz
  • Upload date:
  • Size: 3.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.10 Linux/6.8.0-49-generic

File hashes

Hashes for tqdx-0.1.1.tar.gz
Algorithm Hash digest
SHA256 bab475edf4f68a912b378a71a417a9f5c75ea28c1a9e83fc1d89defb56f28673
MD5 2e3b1572c3eedf3fb11c1ee8e2b6aa6c
BLAKE2b-256 c436c7dc4352321602ed33fad092614d2587b248a75ab586e156d364c713ddbc


File details

Details for the file tqdx-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: tqdx-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 4.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.10 Linux/6.8.0-49-generic

File hashes

Hashes for tqdx-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 1c6a57936325cb26bfdfcdbf12b15683fcf5da9491eed289ea94270e85eff695
MD5 933d64868708118d4b2afee66869c788
BLAKE2b-256 6b5faa0cc39f6894709e9e1bf52e203320e00c6a6858f23b4ccb936f9ebec47f
