
JAX-compatible progress bars for scan and fori_loop.


tqdx

Adds tqdm progress bars to jax.lax.scan and jax.lax.fori_loop. Python progress bars such as tqdm do not work inside JAX's jit-compiled functions, because traced code cannot directly perform side effects such as printing. tqdx works around this limitation by using callbacks to update progress bars created on the host.

import tqdx

...
carry, ys = tqdx.scan(f, init, xs)
Processing: 100%|███████████████████████████████████████████| 50/50 [02:38<00:00,  3.20s/it]
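
The callback mechanism described above can be sketched with plain JAX and tqdm. This is a minimal sketch, not tqdx's actual implementation. `scan_with_progress` is a hypothetical helper name, and the sketch assumes `xs` is a non-empty array (or pytree of arrays) with a leading axis to scan over. The iteration index is scanned alongside `xs`. On each iteration, `jax.debug.callback` (with `ordered=True`, so the host calls run in iteration order) opens, advances, or closes a tqdm bar on the host.

```python
import jax
import jax.numpy as jnp
from tqdm import tqdm


def scan_with_progress(f, init, xs, desc="Processing"):
    """Hypothetical helper (not tqdx's API): jax.lax.scan with a host-side tqdm bar."""
    # Number of iterations, taken from the leading axis of xs (assumed non-empty)
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    state = {}  # host-side storage for the bar; never seen by the traced code

    def _progress(i):
        # Runs on the host with the concrete iteration index
        i = int(i)
        if i == 0:
            state["bar"] = tqdm(total=n, desc=desc)
        state["bar"].update(1)
        if i == n - 1:
            state.pop("bar").close()

    def body(carry, x_and_i):
        x, i = x_and_i
        # ordered=True keeps the host calls in iteration order
        jax.debug.callback(_progress, i, ordered=True)
        return f(carry, x)

    # Scan over (xs, index) so the body knows which iteration it is on
    return jax.lax.scan(body, init, (xs, jnp.arange(n)))


def step(carry, x):
    return carry + x, carry * 2


carry, ys = jax.jit(lambda v: scan_with_progress(step, jnp.int32(0), v))(jnp.arange(5))
print(int(carry), ys.tolist())  # 10 [0, 0, 2, 6, 12]
```

This sketch creates the bar inside the callback rather than at trace time. A jit-compiled function is traced only once, so a bar built during tracing would be reused across calls. Creating it in the callback gives each call a fresh bar.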

Features

  • Progress bars for JAX: See the progress of your computations when using jax.lax.scan and jax.lax.fori_loop.
  • Works with jax.jit: Progress bars show up even inside jit-compiled code.
  • Minimal syntax change: Just replace your calls to jax.lax.scan and jax.lax.fori_loop with tqdx.scan and tqdx.fori_loop.
  • No extra dependencies: Only requires JAX and tqdm.
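
The body of jax.lax.fori_loop already receives the loop index, so the same callback idea carries over with less plumbing. This is again a hypothetical sketch: `fori_loop_with_progress` is not part of tqdx. It assumes `lower` and `upper` are static Python ints, so the host callback can compare the index against them.

```python
import jax
import jax.numpy as jnp
from tqdm import tqdm


def fori_loop_with_progress(lower, upper, body_fun, init_val, desc="Processing"):
    """Hypothetical helper (not tqdx's API): jax.lax.fori_loop with a host-side tqdm bar.

    Assumes lower and upper are static Python ints.
    """
    total = upper - lower
    state = {}  # host-side storage for the bar

    def _progress(i):
        # Runs on the host with the concrete loop index
        i = int(i)
        if i == lower:
            state["bar"] = tqdm(total=total, desc=desc)
        state["bar"].update(1)
        if i == upper - 1:
            state.pop("bar").close()

    def wrapped(i, val):
        # Report progress on the host, then run the user's body unchanged
        jax.debug.callback(_progress, i, ordered=True)
        return body_fun(i, val)

    return jax.lax.fori_loop(lower, upper, wrapped, init_val)


out = jax.jit(lambda: fori_loop_with_progress(0, 10, lambda i, v: v + i, jnp.int32(0)))()
print(int(out))  # 45
```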

Usage

The following example demonstrates how to use tqdx with jax.lax.scan and jax.lax.fori_loop. You can arbitrarily nest these functions, and the progress bars will still work correctly.

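import jax
import jax.numpy as jnp
from time import sleep

import tqdx

def step(carry, x):
    def body_fun(i, val):
        # Simulate slow work in each inner iteration (the sleep runs on the host)
        jax.debug.callback(lambda: sleep(0.5))
        return val + i

    # Simulate slow work once per outer step
    jax.debug.callback(lambda: sleep(0.5))
    # Inner loop with its own progress bar, nested inside the scan
    carry = tqdx.fori_loop(0, 10, body_fun, carry)
    return carry, x + 1

def f(xs):
    # Outer loop with a progress bar
    return tqdx.scan(step, 0, xs)

xs = jnp.arange(10)
result, _ = jax.jit(f)(xs)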
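
For reference, here is the same computation written with plain jax.lax.scan and jax.lax.fori_loop, with the sleeps and progress bars removed. Each scan step adds 0 + 1 + ... + 9 = 45 to the carry, so after 10 steps `result` is 450, and the stacked outputs are `xs + 1`. If tqdx.scan and tqdx.fori_loop are the drop-in replacements the Features list describes, the tqdx version should return the same values.

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # Inner loop: adds 0 + 1 + ... + 9 = 45 to the carry
    carry = jax.lax.fori_loop(0, 10, lambda i, val: val + i, carry)
    return carry, x + 1

def f(xs):
    return jax.lax.scan(step, 0, xs)

xs = jnp.arange(10)
result, ys = jax.jit(f)(xs)
print(int(result))  # 450
print(ys.tolist())  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```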

Installation

pip install tqdx



Download files


Source Distribution

tqdx-0.1.0.tar.gz (3.5 kB)

Uploaded Source

Built Distribution


tqdx-0.1.0-py3-none-any.whl (4.4 kB)

Uploaded Python 3

File details

Details for the file tqdx-0.1.0.tar.gz.

File metadata

  • Download URL: tqdx-0.1.0.tar.gz
  • Upload date:
  • Size: 3.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.10 Linux/6.8.0-49-generic

File hashes

Hashes for tqdx-0.1.0.tar.gz
Algorithm Hash digest
SHA256 f6cf5d01132bc9083c2b4f2f6e31b3eac6c4b73cb69c698d31b65c82b0c1b45b
MD5 1722a17065a3c1394bb96fa6bd2de986
BLAKE2b-256 719bb15c1eacc562b0ab0da166786128117eeb6a8601a7fe4516725ecfcb728b


File details

Details for the file tqdx-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: tqdx-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 4.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.10 Linux/6.8.0-49-generic

File hashes

Hashes for tqdx-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 d5253ab9f96c6af9f5be36f7de830cb1e431c94086206ec8d6bf42dedf166720
MD5 602441853b9f79e52853aa1a7cfa9aa6
BLAKE2b-256 d0f02cbfa5615c0970c6275e9a9b7b88878fee2dc15b75550ac2dd8ce311f44c

