JAX-compatible progress bars for jax.lax.scan and jax.lax.fori_loop.

tqdx

Adds tqdm progress bars to jax.lax.scan and jax.lax.fori_loop. Common Python progress bars such as tqdm do not work inside JAX's jit-compiled functions, because traced code cannot perform ordinary Python side effects such as printing. tqdx works around this limitation by creating the progress bars on the host and updating them through callbacks.

import tqdx

...
carry, ys = tqdx.scan(f, init, xs)
Processing: 100%|███████████████████████████████████████████| 50/50 [02:38<00:00,  3.20s/it]
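To illustrate the callback approach described above, here is a minimal sketch of a scan that drives a host-side tqdm bar. It is not tqdx's implementation: the helper name progress_scan and its desc parameter are hypothetical, and it assumes xs is a non-empty pytree of arrays whose leading axis gives the number of steps. The bar lives in an ordinary Python dict on the host, and jax.debug.callback, which is permitted inside traced and jit-compiled code, advances it once per step.

```python
import jax
import jax.numpy as jnp
from tqdm import tqdm


def progress_scan(f, init, xs, desc="Processing"):
    """Hypothetical helper (not the tqdx API): jax.lax.scan plus a host-side tqdm bar."""
    # Number of steps, read from the leading axis of xs at trace time (assumes xs is non-empty).
    length = jax.tree_util.tree_leaves(xs)[0].shape[0]
    state = {}  # host-side mutable state; JAX never traces it

    def on_step():
        # Runs on the host once per scan step, via jax.debug.callback.
        if "bar" not in state:
            state["bar"] = tqdm(total=length, desc=desc)
        state["bar"].update(1)
        if state["bar"].n >= length:
            state["bar"].close()
            del state["bar"]  # a later call (e.g. of a cached jit function) opens a fresh bar

    def wrapped(carry, x):
        carry, y = f(carry, x)
        jax.debug.callback(on_step)  # host side effect that survives tracing and jit
        return carry, y

    return jax.lax.scan(wrapped, init, xs)
```

For example, progress_scan(lambda c, x: (c + x, c), jnp.int32(0), jnp.arange(5)) returns the same carry and outputs as jax.lax.scan while drawing a five-step bar. Counting updates instead of reading the loop index keeps the sketch independent of callback ordering; since callbacks can run asynchronously, call jax.effects_barrier() if the bar must be finished before continuing.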

Features

  • Progress bars for JAX: See the progress of your computations when using jax.lax.scan and jax.lax.fori_loop.
  • Works with jax.jit: Progress bars show up even inside jit-compiled code.
  • Minimal syntax change: Just replace your calls to jax.lax.scan and jax.lax.fori_loop with tqdx.scan and tqdx.fori_loop.
  • No extra dependencies: Only requires JAX and tqdm.
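Because tqdx.scan and tqdx.fori_loop are meant as call-for-call replacements, one option is to choose the implementation at import time so the same code runs with or without tqdx installed. This sketch assumes the tqdx functions accept the same positional arguments as their jax.lax counterparts, as the examples on this page suggest; keyword options such as length or reverse are not covered, and the helper names cumulative_sum and triangular are illustrative only.

```python
import jax
import jax.numpy as jnp

# Prefer tqdx when it is installed; otherwise fall back to the plain jax.lax
# primitives. Assumes tqdx.scan / tqdx.fori_loop take the same positional arguments.
try:
    import tqdx
    scan, fori_loop = tqdx.scan, tqdx.fori_loop
except ImportError:
    scan, fori_loop = jax.lax.scan, jax.lax.fori_loop


def cumulative_sum(xs):
    # Carry is the running total; each output is the total before adding x.
    return scan(lambda c, x: (c + x, c), jnp.zeros((), xs.dtype), xs)


def triangular(n):
    # Sum of 0 .. n-1 using a static trip count.
    return fori_loop(0, n, lambda i, acc: acc + i, jnp.int32(0))
```

Only the two assignments in the try/except differ between the two paths; the calling code stays the same either way.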

Usage

The following example uses tqdx.scan and tqdx.fori_loop in place of jax.lax.scan and jax.lax.fori_loop. These functions can be nested arbitrarily, and the progress bars still work correctly.

import jax
import tqdx
from time import sleep

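def step(carry, x):
    def body_fun(i, val):
        jax.debug.callback(lambda: sleep(0.5))  # simulate slow work on the host
        return val + i
    jax.debug.callback(lambda: sleep(0.5))  # simulate slow work on the host
    carry = tqdx.fori_loop(0, 10, body_fun, carry)  # inner loop, nested inside the scan
    return carry, x + 1

def f(xs):
    return tqdx.scan(step, 0, xs)  # outer loop over the 10 elements of xs


xs = jax.numpy.arange(10)
result, _ = jax.jit(f)(xs)  # the progress bars still appear under jax.jit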
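The same computation written directly against jax.lax, with the simulated delays removed, is a convenient reference for checking what the tqdx version returns. Each of the ten scan steps runs an inner loop that adds 0 + 1 + ... + 9 = 45 to the carry, so the final carry is 450, and the stacked outputs are xs + 1.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # Inner loop: adds 0 + 1 + ... + 9 = 45 to the carry.
    carry = jax.lax.fori_loop(0, 10, lambda i, val: val + i, carry)
    return carry, x + 1


@jax.jit
def f(xs):
    # Outer loop: one step per element of xs.
    return jax.lax.scan(step, jnp.int32(0), xs)


result, ys = f(jnp.arange(10))
print(int(result))  # 450
print(ys.tolist())  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```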

Installation

pip install tqdx

Download files

Source Distribution

tqdx-0.1.3.tar.gz (3.5 kB)

Built Distribution

tqdx-0.1.3-py3-none-any.whl (4.5 kB)

File details

Details for the file tqdx-0.1.3.tar.gz.

File metadata

  • Download URL: tqdx-0.1.3.tar.gz
  • Upload date:
  • Size: 3.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.10 Linux/6.8.0-49-generic

File hashes

Hashes for tqdx-0.1.3.tar.gz
Algorithm    Hash digest
SHA256       fb375b501ab3e26272e997eec073d97ff69900f8b424d5039afcce5963c1a07a
MD5          6491a935faa0fbeb236ba00ee8111685
BLAKE2b-256  523445558925de9ea31d14d2e6541f325a119b95f7e912c1b5afc6851af783a1

File details

Details for the file tqdx-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: tqdx-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 4.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.10 Linux/6.8.0-49-generic

File hashes

Hashes for tqdx-0.1.3-py3-none-any.whl
Algorithm    Hash digest
SHA256       e615d03f988b3e2e8931d38477bfa1a761452b575dd256f87c08731165615bb1
MD5          9c663c7740f529bfc1ddda34e4faf3b0
BLAKE2b-256  6a40466c9b5f1323747c943f53b64afefe6a65f5eb780dd040b5ead648e86fd2
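The SHA256 digests listed above can be checked after downloading a file manually. A minimal sketch using only the standard library; the function name sha256_matches is illustrative, not part of any tool.

```python
import hashlib


def sha256_matches(path, expected_hex, chunk_size=1 << 16):
    """Return True if the file at `path` has the given SHA256 hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # Hash in chunks so large files are not read into memory at once.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_hex.lower()


# Example: check a manually downloaded sdist against the digest listed above.
# sha256_matches("tqdx-0.1.3.tar.gz",
#                "fb375b501ab3e26272e997eec073d97ff69900f8b424d5039afcce5963c1a07a")
```

pip can enforce the same check through its hash-checking mode, by adding --hash=sha256:... options to the tqdx line of a requirements file.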
