
Progress meters for JAX loops and scans

jax-progress

Progress meters for JAX loops, scans, and Diffrax solves.

Features

  • tqdm progress bars for JAX loops (jax.lax.scan, jax.lax.while_loop).
  • Support for vmap with correct progress tracking (skips batched updates and tracks the n slowest processes, set with max_bars).
  • Support for shard_map with per-device progress tracking.
  • Diffrax-compatible progress meter.
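All of these features depend on getting progress information out of compiled code. One standard JAX mechanism for that is a host callback. The sketch below illustrates the general idea under that assumption; it is not jax-progress's actual implementation. `_report` and `seen_steps` are hypothetical names, and a real meter would call tqdm's `update` where this sketch appends to a list.

```python
import jax
import jax.numpy as jnp

# Hypothetical host-side record standing in for a tqdm bar.
seen_steps = []


def _report(step):
    # Runs on the host with a concrete NumPy value, once per loop iteration.
    seen_steps.append(int(step))


@jax.jit
def run(xs):
    def body(carry, x):
        step, total = carry
        # Send the current step index to the host without breaking tracing.
        jax.debug.callback(_report, step)
        return (step + 1, total + x), None

    init = (jnp.int32(0), jnp.float32(0.0))
    (steps, total), _ = jax.lax.scan(body, init, xs)
    return steps, total


steps, total = run(jnp.ones(100))
jax.effects_barrier()  # wait until all pending host callbacks have run
print(int(steps), float(total), len(seen_steps))
```

The callback carries only a scalar step index, so host traffic stays small even for long loops.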

Installation

pip install jax-progress

Usage

Basic vmap example

import jax
import jax.numpy as jnp
from jax_progress import TqdmProgressMeter

# Limit to 3 progress bars (shows 3 slowest tasks)
pbar = TqdmProgressMeter(total=100, max_bars=3)

def task(data):
    state = pbar.init(vmapped_element=data)
    def body(carry, x):
        return pbar.step(carry, progress=1), x
    state, _ = jax.lax.scan(body, state, data)
    pbar.close(state)
    return data.sum()

# Run 10 tasks in parallel, but only show 3 slowest
results = jax.vmap(task)(jnp.ones((10, 100)))
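With max_bars=3 only the slowest tasks get a bar. Assuming "slowest" means the tasks with the fewest completed steps, choosing them from a snapshot of per-task progress is an ascending argsort. `slowest_tasks` is a hypothetical helper for illustration, not part of jax-progress.

```python
import jax.numpy as jnp


def slowest_tasks(progress, max_bars):
    """Return indices of the `max_bars` tasks with the least progress, slowest first."""
    # argsort is ascending, so the least-advanced tasks come first.
    return jnp.argsort(progress)[:max_bars]


# Hypothetical snapshot: steps completed by each of six vmapped tasks.
progress = jnp.array([40, 10, 75, 12, 90, 55])
idx = slowest_tasks(progress, 3)
print(idx.tolist(), progress[idx].tolist())
```

Because `max_bars` is a Python int, the slice has a static size, so the helper could also run under `jax.jit` with `max_bars` marked static.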

shard_map example

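from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax_progress import TqdmProgressMeter

# Needs 4 devices; on CPU, set XLA_FLAGS=--xla_force_host_platform_device_count=4 before importing jax
mesh = jax.make_mesh((4,), ('x',))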
pbar = TqdmProgressMeter(total=100)

@partial(jax.shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def sharded_task(data):
    state = pbar.init(spec=P('x'))
    def body(carry, x):
        return pbar.step(carry, progress=1), x
    state, _ = jax.lax.scan(body, state, jnp.arange(100))
    pbar.close(state)
    return data

results = sharded_task(jnp.ones(4))
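Per-device tracking requires each shard to know which device it runs on. Inside shard_map, `jax.lax.axis_index` gives the shard's position along a mesh axis. The sketch below builds a mesh over whatever devices are available (so it also runs on a single CPU) and has each shard report its index. Keying a progress bar by this index is an assumption about how per-device tracking could work, not a description of jax-progress internals; `device_ids` is a hypothetical name.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))


@partial(shard_map, mesh=mesh, in_specs=P("x"), out_specs=P("x"))
def device_ids(x):
    # Each shard holds one element and adds its own position along axis "x".
    return x + jax.lax.axis_index("x")


ids = device_ids(jnp.zeros(n, dtype=jnp.int32))
print(ids.tolist())
```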

Diffrax integration (drop-in replacement)

TqdmProgressMeter can be passed to diffrax.diffeqsolve through its progress_meter argument, as a drop-in replacement for Diffrax's built-in progress meters:

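import diffrax
import jax.numpy as jnp
from jax_progress import TqdmProgressMeter

# Example problem: exponential decay, dy/dt = -y
term = diffrax.ODETerm(lambda t, y, args: -y)
solver = diffrax.Tsit5()
y0 = jnp.array(1.0)
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)

# Create the progress meter with percent_progress=True for Diffrax
pbar = TqdmProgressMeter(total=100, percent_progress=True)

# Pass it directly to diffeqsolve
sol = diffrax.diffeqsolve(
    term, solver, t0=0.0, t1=10.0, dt0=0.01, y0=y0,
    stepsize_controller=stepsize_controller,
    progress_meter=pbar,  # drop-in replacement
)
pbar.terminate()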
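Diffrax hands its progress meter a cumulative fraction of the integration interval in [0, 1], while a tqdm bar is advanced by increments. Assuming percent_progress=True means "scale that fraction to percent of total", the conversion can be sketched as below. `percent_increment` is a hypothetical helper, not jax-progress's API.

```python
def percent_increment(prev_fraction: float, fraction: float, total: int = 100) -> int:
    """Turn two cumulative progress fractions in [0, 1] into a bar increment."""
    # Clamp, in case a solver reports values marginally outside [0, 1].
    prev = min(max(prev_fraction, 0.0), 1.0)
    cur = min(max(fraction, 0.0), 1.0)
    # Truncating cumulative values (not deltas) keeps the increments summing to `total`.
    return int(cur * total) - int(prev * total)


# Simulated cumulative fractions, as an adaptive solver might report them.
fractions = [0.0, 0.013, 0.25, 0.5, 0.999, 1.0]
increments = [percent_increment(a, b) for a, b in zip(fractions, fractions[1:])]
print(increments, sum(increments))
```

Working from cumulative values avoids the rounding drift that would build up if each small delta were rounded separately.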

Note: You can combine vmap and shard_map for multi-level parallelism. See the examples/ directory in the repository for more.
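The nesting behind multi-level parallelism can be sketched in plain JAX: shard_map splits a batch of tasks across devices, and jax.vmap runs each device's local tasks in parallel, each with its own lax.scan. No progress meter is attached here; the sketch only shows the structure, assuming a mesh over all available devices, and `per_device` is a hypothetical name.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

n_devices = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))


@partial(shard_map, mesh=mesh, in_specs=P("x"), out_specs=P("x"))
def per_device(batch):
    # Level 1: each device receives a slice of rows along mesh axis "x".
    def task(row):
        # Level 2: one task per row, looping over its elements with scan.
        def body(carry, x):
            return carry + x, None

        # Derive the initial carry from `row` so it has the same type
        # (including per-device variation) as the loop output.
        total, _ = jax.lax.scan(body, row[0] * 0.0, row)
        return total

    return jax.vmap(task)(batch)


data = jnp.ones((2 * n_devices, 100))
out = per_device(data)
print(out.shape, out.tolist())
```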


Download files

Source Distribution

jax_progress-0.1.0.tar.gz (12.1 kB)

Uploaded Source

Built Distribution

jax_progress-0.1.0-py3-none-any.whl (9.3 kB)

Uploaded Python 3

File details

Details for the file jax_progress-0.1.0.tar.gz.

File metadata

  • Download URL: jax_progress-0.1.0.tar.gz
  • Upload date:
  • Size: 12.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jax_progress-0.1.0.tar.gz
  • SHA256: 700ebd006a15cb8e556fac3628a3cfef17071dbafb4c2d20f868d8d4c2b18f40
  • MD5: f628906a37403a8ca11409ce0d82b8e5
  • BLAKE2b-256: 1e3c72bc8fc97cdd1b566233ce846ea81b7906d377c468b3837a3c17e77794a3

Provenance

The following attestation bundles were made for jax_progress-0.1.0.tar.gz:

Publisher: publish.yml on ASKabalan/jax-progress

File details

Details for the file jax_progress-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: jax_progress-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 9.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jax_progress-0.1.0-py3-none-any.whl
  • SHA256: 3fae8ce5045a525ae40288757749e6c9d3830223c9b41d925c50b32c1ce5715e
  • MD5: f88c526a2c8c818d3dcd940bb8f8b946
  • BLAKE2b-256: 3d3ea1d5c39392c0e762da98d9caaf4e5e5eea34e2ecf21f9569165daa54c288

Provenance

The following attestation bundles were made for jax_progress-0.1.0-py3-none-any.whl:

Publisher: publish.yml on ASKabalan/jax-progress
