JAX Synergistic Memory Inspector

jax-smi is a tool for real-time inspection of the memory usage of a JAX process. It is similar to nvidia-smi for GPU, but works on multiple platforms including CPU, GPU and TPU.

On TPU platforms, jax-smi is the only way to monitor TPU memory usage. On GPU platforms, jax-smi is also preferable to nvidia-smi. Because JAX preallocates a large fraction of the GPU memory by default (75% in recent releases, 90% in older ones), nvidia-smi shows the reserved pool rather than the memory the JAX process is actually using.

This project is supported by Cloud TPUs from Google's TPU Research Cloud (TRC).

Installation

Install Go. On Ubuntu, this can usually be done with:

sudo apt-get install golang

If you followed tpu-starter to set up the TPU environment, Go should already be installed.

Then install jax-smi with:

pip install jax-smi

Usage

In your JAX script:

from jax_smi import initialise_tracking
initialise_tracking()
# some computation...

Open a shell and run:

jax-smi

Approach

A separate thread saves the memory profile to /dev/shm/memory.prof once per second using jax.profiler.save_device_memory_profile().

The saved profile is then inspected with go tool pprof -tags /dev/shm/memory.prof.

See https://twitter.com/ayaka14732/status/1565013139594551296 for more details.
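
The saving half of this approach can be sketched roughly as follows. This is an illustrative sketch, not the package's actual source: the helper names start_periodic_saver and track_jax_memory are hypothetical, and the only JAX call used is jax.profiler.save_device_memory_profile(), which is named above.

```python
import threading
from typing import Callable


def start_periodic_saver(
    save: Callable[[str], None],
    path: str,
    interval: float = 1.0,
) -> threading.Event:
    """Call save(path) every `interval` seconds in a daemon thread.

    Returns an Event; setting it stops the loop. (Hypothetical helper,
    not part of jax-smi.)
    """
    stop = threading.Event()

    def loop() -> None:
        while not stop.is_set():
            save(path)
            # wait() returns early once stop is set, so shutdown is prompt
            stop.wait(interval)

    threading.Thread(target=loop, daemon=True).start()
    return stop


def track_jax_memory(
    path: str = "/dev/shm/memory.prof",
    interval: float = 1.0,
) -> threading.Event:
    """Periodically dump the JAX device memory profile to `path`."""
    # Imported lazily so the generic helper above works without JAX installed.
    import jax.profiler

    return start_periodic_saver(
        jax.profiler.save_device_memory_profile, path, interval
    )
```

Calling track_jax_memory() at the top of a script would play roughly the same role as initialise_tracking(). The thread is a daemon, so it does not keep the process alive after the main computation finishes.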

Limitations

Tracing can only be performed by one process at a time. If tracing is performed by multiple JAX processes, they will write the memory profiles to the same file, which will lead to conflicts.

The jax-smi command line tool cannot detect if a memory profile file is out of date. Therefore, even if no JAX process is running, the tool will still read the outdated memory profile and report outdated memory usage information.
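
One possible workaround for the stale-profile limitation is to check the file's modification time before trusting the output. The sketch below is not part of jax-smi; the function names and the 5-second threshold are assumptions chosen for illustration, given that the profile is rewritten about once per second.

```python
import os
import time
from typing import Optional


def profile_age_seconds(path: str, now: Optional[float] = None) -> float:
    """Return how many seconds ago the file at `path` was last modified."""
    if now is None:
        now = time.time()
    return now - os.path.getmtime(path)


def is_stale(path: str = "/dev/shm/memory.prof", max_age: float = 5.0) -> bool:
    """Return True if the profile is missing or older than `max_age` seconds.

    The threshold is an assumption: it should be a few multiples of the
    save interval so that a slow write is not mistaken for a dead process.
    """
    try:
        return profile_age_seconds(path) > max_age
    except FileNotFoundError:
        return True
```

Running is_stale() before jax-smi would indicate whether the reported figures are likely to come from a process that is still writing profiles.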
