
JAX-compatible progress bars for jax.lax.scan and jax.lax.fori_loop.


tqdx

tqdx adds tqdm progress bars to jax.lax.scan and jax.lax.fori_loop. Ordinary Python progress bars such as tqdm do not work inside jit-compiled JAX functions, because Python side effects like printing run only while JAX traces the function, not on every iteration of the compiled loop. tqdx works around this by using callbacks to update progress bars that are created on the host.
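To see the restriction concretely, here is a minimal sketch in plain JAX (no tqdx involved). A Python-level side effect, standing in for a tqdm update, runs only while JAX traces the loop body, whereas `jax.debug.callback` is staged into the compiled program and runs on the host at every iteration. The list names are illustrative only.

```python
import jax
import jax.numpy as jnp

trace_time_updates = []     # appended to by an ordinary Python side effect
host_callback_updates = []  # appended to via jax.debug.callback


def step(carry, x):
    # Ordinary Python code (like a tqdm update) runs only while JAX traces `step`.
    trace_time_updates.append(1)
    # jax.debug.callback is staged into the compiled program and runs on the host each iteration.
    jax.debug.callback(lambda: host_callback_updates.append(None))
    return carry + x, None


@jax.jit
def total(xs):
    carry, _ = jax.lax.scan(step, jnp.zeros(()), xs)
    return carry


xs = jnp.arange(5.0)
total(xs)
total(xs)              # second call reuses the compiled function, so `step` is not retraced
jax.effects_barrier()  # wait until all pending host callbacks have run

print(len(trace_time_updates), len(host_callback_updates))
```

The first count stays at the number of traces, while the second grows by five (one per scan step) on every call.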

import tqdx

...
carry, ys = tqdx.scan(f, init, xs)
Processing: 100%|███████████████████████████████████████████| 50/50 [02:38<00:00,  3.20s/it]
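The callback approach can be sketched in a few lines of plain JAX and tqdm. This is a minimal illustration, not tqdx's source: the helper name `scan_with_progress` and its `desc` argument are hypothetical, and the sketch assumes `xs` is a single array whose leading axis is the one being scanned. The bar is created and updated inside a host callback, so every call of a jit-compiled function gets a fresh bar.

```python
import jax
import jax.numpy as jnp
from tqdm import tqdm


def scan_with_progress(f, init, xs, desc="Processing"):
    """Hypothetical helper: jax.lax.scan plus a host-side tqdm bar.

    Assumes xs is a single array whose leading axis is the scan axis.
    """
    n = xs.shape[0]
    bars = {}  # host-side state, only touched from inside the callback

    def tick(i):
        # Runs on the host once per scan step; i arrives as a NumPy scalar.
        i = int(i)
        if i == 0:
            bars["bar"] = tqdm(total=n, desc=desc)
        bars["bar"].update(1)
        if i == n - 1:
            bars["bar"].close()

    def body(carry, inp):
        i, x = inp
        # ordered=True makes JAX run the host callbacks in program order.
        jax.debug.callback(tick, i, ordered=True)
        return f(carry, x)

    # Scan over (step index, x) so the callback knows which step it is reporting.
    return jax.lax.scan(body, init, (jnp.arange(n), xs))


def step(carry, x):
    # Running sum as the carry; emit the pre-update sum as the per-step output.
    return carry + x, carry


carry, ys = jax.jit(lambda xs: scan_with_progress(step, jnp.zeros(()), xs))(jnp.arange(50.0))
jax.effects_barrier()
```

Using `ordered=True` guarantees that the `i == 0` branch creates the bar before any later update reaches it; unordered callbacks are cheaper but may run in any order.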

Features

  • Progress bars for JAX: See the progress of your computations when using jax.lax.scan and jax.lax.fori_loop.
  • Works with jax.jit: Progress bars show up even inside jit-compiled code.
  • Minimal syntax change: Just replace your calls to jax.lax.scan and jax.lax.fori_loop with tqdx.scan and tqdx.fori_loop.
  • No extra dependencies: Only requires JAX and tqdm.

Usage

The following example uses tqdx.scan and tqdx.fori_loop in place of jax.lax.scan and jax.lax.fori_loop. The two functions can be nested arbitrarily, and the progress bars still work correctly.

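import jax
import jax.numpy as jnp
import tqdx
from time import sleep

def step(carry, x):
    def body_fun(i, val):
        # Simulate slow work on the host so the inner progress bar visibly advances.
        jax.debug.callback(lambda: sleep(0.5))
        return val + i
    # Simulate slow work for each outer step as well.
    jax.debug.callback(lambda: sleep(0.5))
    # Inner loop: tqdx.fori_loop is a drop-in replacement for jax.lax.fori_loop.
    carry = tqdx.fori_loop(0, 10, body_fun, carry)
    return carry, x + 1

def f(xs):
    # Outer loop: tqdx.scan is a drop-in replacement for jax.lax.scan.
    return tqdx.scan(step, 0, xs)


xs = jnp.arange(10)
result, _ = jax.jit(f)(xs)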
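In a nested setup like the one above, each inner bar needs its own terminal line so it does not overwrite the outer one. One way to do that for fori_loop is sketched below, assuming tqdm's `position` and `leave` arguments (tqdx may handle nesting differently). `fori_loop_with_progress` is a hypothetical name, and `lower`/`upper` are assumed to be Python ints so the bar's total is known on the host.

```python
import jax
import jax.numpy as jnp
from tqdm import tqdm


def fori_loop_with_progress(lower, upper, body_fun, init_val, desc="Inner", position=1):
    """Hypothetical helper: jax.lax.fori_loop plus a host-side tqdm bar.

    Assumes lower/upper are Python ints. position=1 with leave=False draws the
    bar one line below an outer bar and erases it when the inner loop finishes.
    """
    total = upper - lower
    bars = {}

    def tick(i):
        # Runs on the host once per iteration; i arrives as a NumPy scalar.
        i = int(i)
        if i == lower:
            bars["bar"] = tqdm(total=total, desc=desc, position=position, leave=False)
        bars["bar"].update(1)
        if i == upper - 1:
            bars["bar"].close()

    def wrapped(i, val):
        jax.debug.callback(tick, i, ordered=True)
        return body_fun(i, val)

    return jax.lax.fori_loop(lower, upper, wrapped, init_val)


def step(carry, x):
    # Each outer step runs a 10-iteration inner loop with its own fresh bar.
    carry = fori_loop_with_progress(0, 10, lambda i, v: v + i, carry)
    return carry, x + 1


carry, ys = jax.jit(lambda xs: jax.lax.scan(step, jnp.int32(0), xs))(jnp.arange(10))
jax.effects_barrier()
```

Here the outer loop is a plain jax.lax.scan; an outer bar drawn at `position=0` would sit on the line above the inner one.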

Installation

pip install tqdx



Download files


Source Distribution

tqdx-0.1.2.tar.gz (3.5 kB)

Built Distribution


tqdx-0.1.2-py3-none-any.whl (4.4 kB)

File details

Details for the file tqdx-0.1.2.tar.gz.

File metadata

  • Download URL: tqdx-0.1.2.tar.gz
  • Upload date:
  • Size: 3.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.10 Linux/6.8.0-49-generic

File hashes

Hashes for tqdx-0.1.2.tar.gz
Algorithm Hash digest
SHA256 79e03aaf6d07680b31c63511918deccdb34734f3de27799fbce3cc1dfd100be2
MD5 b02ca95a8edb72349d0232b017ee53bf
BLAKE2b-256 e451a210cb44ae1551932f53a898319e16dbe7588d9f1827864630034069cc77
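A downloaded file can be checked against the published digests before installing. A minimal sketch using Python's standard hashlib; the helper names are illustrative, and the expected value is the SHA256 digest listed above for tqdx-0.1.2.tar.gz.

```python
import hashlib

# Published SHA256 digest for tqdx-0.1.2.tar.gz, copied from the table above.
EXPECTED_SHA256 = "79e03aaf6d07680b31c63511918deccdb34734f3de27799fbce3cc1dfd100be2"


def sha256_of(path):
    """Return the hex SHA256 digest of a file, reading it in 64 KiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(65536), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected.lower()
```

pip can also enforce this itself in hash-checking mode, for example with a requirements line of the form `tqdx==0.1.2 --hash=sha256:<digest>`.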


File details

Details for the file tqdx-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: tqdx-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 4.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.10 Linux/6.8.0-49-generic

File hashes

Hashes for tqdx-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 e523af4cf8d980edffeb5bf69ebf43a45dee2f43f565024a6129e58103c3268e
MD5 972af8b5d9fdf9b4f6f5ea78b577dd19
BLAKE2b-256 8b300deca465ac250a35a28024219d7b0cc5716735cf6d5fd5d9eea00daf348d

