
A comprehensive benchmarking and profiling tool designed for JAX in HPC environments, offering automated instrumentation, strong/weak scaling analysis, and performance visualization.


JAX HPC Profiler

License: GPLv3

JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.

Introduction

JAX HPC Profiler allows users to:

  1. Generate CSV files containing performance data.
  2. Concatenate multiple CSV files from different runs.
  3. Plot the performance data for analysis.

Installation

To install the package, run the following command:

pip install jax-hpc-profiler

Generating CSV Files Using the Timer Class

To generate CSV files, you can use the Timer class provided in the jax_hpc_profiler.timer module. This class helps in timing functions and saving the timing results to CSV files.

Example Usage

import jax
from jax_hpc_profiler import Timer

def fcn(m, n, k):
    return jax.numpy.dot(m, n) + k

# save_jaxpr=True also records the traced jaxpr alongside the timings
timer = Timer(save_jaxpr=True)
m = jax.numpy.ones((1000, 1000))
n = jax.numpy.ones((1000, 1000))
k = jax.numpy.ones((1000, 1000))

# time the first call separately to capture the JIT compilation cost
timer.chrono_jit(fcn, m, n, k)
# then time steady-state executions
for _ in range(10):
    timer.chrono_fun(fcn, m, n, k)

meta_data = {
  "function": "fcn",
  "precision": "float32",
  "x": 1000,
  "y": 1000,
  "z": 1000,
  "px": 1,
  "py": 1,
  "backend": "NCCL",
  "nodes": 1
}
extra_info = {
    "done": "yes"
}

timer.report("examples/profiling/test.csv", **meta_data, extra_info=extra_info)

timer.report has sensible defaults. Its parameters are:

  • csv_filename: The path to the CSV file to save the timing data (required).
  • function: The name of the function being timed (required).
  • x: The size of the input data in the x dimension (required).
  • y: The size of the input data in the y dimension (by default same as x).
  • z: The size of the input data in the z dimension (by default same as x).
  • precision: The precision of the data (default: "float32").
  • px: The number of partitions in the x dimension (default: 1).
  • py: The number of partitions in the y dimension (default: 1).
  • backend: The backend used for computation (default: "NCCL").
  • nodes: The number of nodes used for computation (default: 1).
  • md_filename: The path to the markdown file containing the compiled code and other information (default: {csv_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md).
  • extra_info: Additional information to include in the report (default: {}).

px and py specify the data decomposition. For example, if you have a 2D array of size 1000x1000 partitioned into 4 parts (2x2), you would set px=2 and py=2.
They can also be used in a single-device run to specify a batch size.
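To make the decomposition concrete, here is a small pure-Python sketch of the arithmetic behind px and py (just an illustration of how a 1000x1000 array splits into a 2x2 grid; it does not call the profiler):

```python
# px, py: number of partitions along x and y (as passed to timer.report)
x, y = 1000, 1000
px, py = 2, 2

# each of the px*py partitions holds a local block of this shape
local_shape = (x // px, y // py)
n_partitions = px * py

print(local_shape)    # (500, 500)
print(n_partitions)   # 4
```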

Some decomposition labels are derived automatically and are specific to 3D data decomposition:

  • slab_yz if the distributed axis is the y-axis.
  • slab_xy if the distributed axis is the x-axis.
  • pencils if the distributed axes are the x and y axes.

Multi-GPU Setup

In a multi-GPU setup, the times are automatically averaged across ranks, providing a single performance metric for the entire setup.
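As an illustration of what cross-rank averaging means for the reported metric, here is a toy sketch (the reduction actually happens inside the Timer; the use of a plain arithmetic mean here is an assumption for illustration):

```python
# hypothetical per-rank timings (milliseconds) from a 4-GPU run
per_rank_ms = [12.1, 11.8, 12.4, 12.0]

# a single performance metric for the whole setup:
# an arithmetic mean across ranks is assumed here for illustration
reported_ms = sum(per_rank_ms) / len(per_rank_ms)
print(round(reported_ms, 3))  # 12.075
```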

CSV Structure

The CSV files should follow a specific structure to ensure proper processing and concatenation. The directory structure should be organized by GPU type, with subdirectories for the number of GPUs and the respective CSV files.

Example Directory Structure

root_directory/
├── gpu_1/
│   ├── 2/
│   │   ├── method_1.csv
│   │   ├── method_2.csv
│   │   └── method_3.csv
│   ├── 4/
│   │   ├── method_1.csv
│   │   ├── method_2.csv
│   │   └── method_3.csv
│   └── 8/
│       ├── method_1.csv
│       ├── method_2.csv
│       └── method_3.csv
└── gpu_2/
    ├── 2/
    │   ├── method_1.csv
    │   ├── method_2.csv
    │   └── method_3.csv
    ├── 4/
    │   ├── method_1.csv
    │   ├── method_2.csv
    │   └── method_3.csv
    └── 8/
        ├── method_1.csv
        ├── method_2.csv
        └── method_3.csv
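The layout above can be generated programmatically. A minimal standard-library sketch follows; the gpu_*, GPU-count, and method_* names simply mirror the example tree, not a required naming scheme:

```python
from pathlib import Path
import tempfile

root = Path(tempfile.mkdtemp())

# gpu type -> gpu counts -> per-method CSV files, mirroring the tree above
for gpu in ("gpu_1", "gpu_2"):
    for n_gpus in (2, 4, 8):
        run_dir = root / gpu / str(n_gpus)
        run_dir.mkdir(parents=True, exist_ok=True)
        for method in ("method_1", "method_2", "method_3"):
            (run_dir / f"{method}.csv").touch()

print(sum(1 for _ in root.rglob("*.csv")))  # 18 CSV files in total
```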

Concatenating Files from Different Runs

The plot function expects the directory to be organized as described above, but with the runs for different GPU counts merged into the same directory. The concatenate command merges the CSV files from different runs into a single file per method.

Example Usage

jhp concat /path/to/root_directory /path/to/output

The output will be:

out_directory/
├── gpu_1/
│   ├── method_1.csv
│   ├── method_2.csv
│   └── method_3.csv
└── gpu_2/
    ├── method_1.csv
    ├── method_2.csv
    └── method_3.csv

Inspecting CSV Metadata

You can inspect available metadata in your CSV files using the probe command:

jhp probe -f <csv_files>

This prints the available data sizes, GPU counts, functions, backends, precisions, and other metadata found in the CSV files. Use this to understand what filters to apply before plotting.

Plotting CSV Data

You can plot the performance data using the plot command. The plotting command provides various options to customize the plots.

Usage

jhp plot -f <csv_files> [options]

Options

  • -f, --csv_files: List of CSV files to plot (required).
  • -sc, --scaling: Axis mode (required):
    • data (or d): subplots per data size, x-axis = GPUs (strong scaling view).
    • GPUs (or g): subplots per GPU count, x-axis = data size.
  • -g, --gpus: List of GPU counts to filter.
  • -d, --data_size: Data size queries. Examples: global_2097152, global_128x128x128, local_2097152, local_128x128x128. Bare integers are auto-translated to global_NxNxN (cubed).
  • -fd, --filter_pdims: List of pdims to filter (e.g., 1x4 2x2 4x8).
  • -ps, --pdim_strategy: Strategy for plotting pdims (plot_all, plot_fastest, slab_yz, slab_xy, pencils).
  • -pr, --precision: Precision to filter by (float32, float64).
  • -fn, --function_name: Function names to filter.
  • -pt, --plot_times: Time columns to plot (jit_time, min_time, max_time, mean_time, std_time, last_time). Note: You cannot plot memory and time together.
  • -pm, --plot_memory: Memory columns to plot (generated_code, argument_size, output_size, temp_size). Note: You cannot plot memory and time together.
  • -mu, --memory_units: Memory units to plot (KB, MB, GB, TB).
  • -fs, --figure_size: Figure size.
  • -o, --output: Output file (if none then only show plot).
  • -pd, --print_decompositions: Print decompositions on plot (experimental).
  • -b, --backends: List of backends to include.
  • --ideal_line: Overlay an ideal scaling reference line (1/N for global data sizes, flat for local data sizes).
  • -xs, --xscale: X-axis scale (linear, symlog, log2, log10).
  • -xl, --xlabel: Custom x-axis label.
  • -tl, --title: Custom plot title.
  • -l, --label_text: Custom label for the plot. You can use placeholders: %decomposition% (or %p%), %precision% (or %pr%), %plot_name% (or %pn%), %backend% (or %b%), %node% (or %n%), %methodname% (or %m%), %function% (or %f%).

CLI examples

Strong scaling (subplots per data size, x-axis = GPUs):

jhp plot -f DATA.csv -sc data -d 128 256 512 -pt mean_time --ideal_line

Size scaling (subplots per GPU count, x-axis = data size):

jhp plot -f DATA.csv -sc GPUs -pt mean_time

Examples

The repository includes Jupyter notebook examples:

  • examples/profiling.ipynb: Single-device profiling of JAX and NumPy functions with Timer, CSV report generation, and plotting with plot_by_gpus.
  • examples/distributed_profiling.ipynb: Multi-device profiling with sharded arrays, plot_by_gpus, plot_by_data_size, probe_csv_metadata, and CLI usage.

A multi-GPU example comparing distributed FFT can be found here: jaxdecomp-benchmarks

Project details

Source distribution: jax_hpc_profiler-0.3.1.tar.gz (61.2 kB), uploaded via Trusted Publishing (twine/6.1.0, CPython/3.13.7).
Built distribution: jax_hpc_profiler-0.3.1-py3-none-any.whl (45.1 kB, Python 3).
Both files were published by the python-publish.yml workflow on ASKabalan/jax-hpc-profiler.