Skip to main content

Pytest plugin to profile jitted JAX functions (compile time, runtime, memory).

Project description

pytest-jax-bench

A pytest plugin to benchmark memory usage, compilation time and run time of jitted JAX functions. This code is inspired by pytest-benchmark. It has only been tested on GPU with jax=0.7.2 and CUDA13 so far.

Installation

First install jax with GPU support and then...

pip install -e .[plot]     # With optional plotting support (requires matplotlib)
pip install -e .           # In general

If you don't want to edit the project, you can skip the "-e"

Usage

Define a pytest unit test, e.g. in a file "tests/test_benchmark.py":

import jax
import jax.numpy as jnp

def fft(x): # An example function we want to profile
    return jnp.fft.ifftn(jnp.fft.fftn(x))

def test_fft(jax_bench):  # jax_bench is a fixture that creates a JaxBench object.
    x = jnp.ones((256, 256, 256), dtype=jnp.float32)

    jb = jax_bench(jit_rounds=20, jit_loops=1, jit_warmup=1, eager_rounds=10, eager_warmup=1)
    jb.measure(fn=fft, fn_jit=jax.jit(fft), x=x)

Then simply run your tests as usual with pytest. E.g.

pytest -v

The benchmark results will only be displayed if you use the "-v" (verbose) option. However, they will always be saved to a ".csv" file in a simple to read table -- by default in the .benchmarks directory.

The jb.measure call in the example above will take the following steps:

  • An eager execution warmup run using fn
  • Save the measured peak memory usage. (This measurement has additional requirements as discussed below.)
  • Average the run-time of eager execution over 10 runs using fn.
  • Compile the jitted function using fn_jit
  • Save the predicted memory usage of the jitted function (including peak memory, temporary memory and memory used by folded constants)
  • Average the run-time of the jitted function over 20 runs, using fn_jit

Different stages can be skipped by setting the measurement parameters to 0 or by not passing fn or fn_jit.

jit_loops

The jit_loops parameter can be used to improve profiling -- especially of functions that have a large launch or synchronization overhead. When it is > 1, we will run a compiled for loop with jit_loops iterations jit_rounds+ jit_warmup times. The profiling will still calculate the average run time per call, but each individual measurement will already be averaged of the inner loop iterations so that this gives more accurate results per sample.

Eager exeuction memory

So far I didn't find a reliable way to measure the eager execution peak-memory usage that does not require restarting the measuring process. Therefore, this part of the measurement will be invalid unless you execute with the --forked flag provided by pytest-forked

#pip install pytest-forked  # If you didn't install it already
pytest --forked

However, this may significantly slow down execution and create a bunch of other problems. So I don't recommend using this, unless you really need an eager-memory report.

Optional parameters

All the optional parameters are listed in pytest --help. Since the help file can be a bit overwhelming, you can specifically find all the options defined by pytest-jax-bench (ptjb) as follows:

pytest --help | grep ptjb -A 2

Alternative usage:

Fixtures in pytest don't go too well with syntax highlighting. If you want proper code completion, you can also create the JaxBench object explicitly in your unit test, leading to the exact same result.

from pytest_jax_bench import JaxBench

def test_fft_alt(request):
    x = jnp.ones((256, 256, 256), dtype=jnp.float32)

    jb = JaxBench(request, jit_rounds=20, jit_warmup=1, eager_rounds=10, eager_warmup=1)
    jb.measure(fn_jit=jax.jit(fft), x=x)

Note: We still need to pass the request fixture to the JaxBench object.

Tags

It is possible to do several measurements inside of the same test if you provide a tag to each run:

def rfft(x):
    return jnp.fft.irfftn(jnp.fft.rfftn(x*2.))

def fft(x):
    return jnp.fft.ifftn(jnp.fft.fftn(x*2.))

def test_tags(request):
    x = jnp.ones((256, 256, 256), dtype=jnp.float32)

    jb = JaxBench(request, jit_rounds=10, jit_warmup=1, eager_rounds=5, eager_warmup=1)

    jb.measure(fn=fft, fn_jit=jax.jit(fft), x=x, tag="fft")
    jb.measure(fn=rfft, fn_jit=jax.jit(rfft), x=x, tag="rfft")

So far, this doesn't support eager memory at all (even when using --forked)

In this case tests with several tags will be saved to the same file and they will be plotted together by default. Note that the plots skip some less relevant aspects in this case to keep it simple.

Usage outside of pytest

You can also use the benchmark class independently of pytest:

from pytest_jax_bench import JaxBench

jb = JaxBench(jit_rounds=10, jit_warmup=2, eager_rounds=5, eager_warmup=1)
res, out = jb.measure(fn=rfft, fn_jit=jax.jit(rfft), x=x, write=False)
print(res)

Examples:

For more examples check the examples.

Outputs

Outputs of tests come in three different varieties. (1) Files (.csv) that log the results of the benchmarks (one per test). (2) The terminal output displayed if using "-v" (3) Overview plots that can be created optionally.

The .csv files

Each test creates a csv file in --ptjb-output-dir (defaults to .benchmarks) named after the nodeid of the test. For example this is the file ".benchmarks/tests:test_interface::test_full.csv" that was creates by running the unit tests in this directory several times:

# pytest-jax-bench
# created: 2025-10-16T19:58:11Z
# test_nodeid: tests/test_interface.py::test_full
# backend: gpu
# device: NVIDIA GeForce RTX 4070 Laptop GPU
# First commit: 6bc9501
#  (1) run_id
#  (2) commit
#  (3) commit_run
#  (4) tag
#  (5) compile_ms
#  (6) jit_mean_ms
#  (7) jit_std_ms
#  (8) eager_mean_ms
#  (9) eager_std_ms
# (10) jit_peak_bytes
# (11) jit_constants_bytes
# (12) jit_temporary_bytes
# (13) eager_peak_bytes
# (14) jit_rounds
# (15) jit_warmup
# (16) eager_rounds
# (17) eager_warmup
#        (1)         (2)         (3)         (4)         (5)         (6)         (7)         (8)         (9)        (10)        (11)        (12)        (13)        (14)        (15)        (16)        (17)
           0     6bc9501           0        base       40.65        1.11        0.46        0.78        0.51    25296896           4     8519680    33816576          10           2           5           1
           1    6bc9501+           1        base       46.39        0.99        0.49        1.01        0.52    25296896           4     8519680    33816576          10           2           5           1
           2     76640ba           0        base       43.14        1.14        0.41        0.85        0.57    25296896           4     8519680    33816576          10           2           5           1
           3     76640ba           1        base       43.19        0.91        0.49        0.54        0.02    25296896           4     8519680    33816576          10           2           5           1
 [...]

Each run creates a line in this table. The active git-commit is logged (a "+" indicates a "dirty" commit with some uncommited changes in the directory.) anda per-commit run-id is tracked additionally for convenience. If you want to work with the raw data of these ".csv" files you can use

from pytest_jax_bench import load_bench_data
data = load_bench_data(".benchmarks/tests:test_interface::test_full.csv")

This reads the data into a numpy structured array that you can easily index with strings, e.g.

print(data[-2:]["commit"], data[-2:]["jit_mean_ms"], "+-", data[-2:]["jit_std_ms"])

to print the last two runs' commit, average runtime and measurement uncertainty. Measurements may be set to "np.nan" for invalid timings or "-1" for invalid memory values.

The terminal output

If you run pytest with "-v" option you will get a terminal output similar to this:

pytest-jax-bench terminal example

(Irrelevant output columns may be ommitted.) Each result is compared to a previous run. For now, the comparison run is always the first run that was run within the same active commmit. The intended workflow is to run the benchmark on your clean commit once before you make any changes. Then afterwards you can repeat it frequently while you experiment with the code.

Results that may be particularly relevant are marked in green (improvement) or red (worsened). Benchmarks can fluctuate randomly so don't panic immediately if something flashes up red -- it just means that it might be worth your attention. Also be aware that functions with runtime <~ 1ms are not really well profiled (as you can see in the example above), due to the synchronization overhead -- so prefer to profile runs that take >~ 5 ms.

In the example at hand, you can see that I had improved the function that is called bin test_tags with a very significant reduction in run time and memory usage -- leading to a lot of green marks. However, you can also see that the compile time fluctuated sufficiently for two other tests and the run time for a third one to flash up in color.

Plotting results

You can create plots in two different ways:

  • By directly passing --ptjb-plot to the pytest command
  • By using the ptjb-plot command line tool (supporting the same + some additional options -- see ptjb-plot --help)

Depending on the chosen options, this will either create one big summary plot or an individual plot for each test. For example, this is how one of my individual plots looked, after I had "accidentally" increased the size of my FFT, then panicked and made it too small and finally reverted it to the correct state: pytest-jax-bench plot example

Custom plotting functions

You can define custom plotting functions as follows:

def custom_plot(data):
    fig = plt.figure()
    plt.xlabel("run_id")
    plt.plot(data["run_id"], data["compile_ms"])
    plt.ylabel("compile_ms")
    return fig

@pytest.mark.ptjb(plot=custom_plot)
def test_with_custom_plot(jax_bench):
    x = jnp.ones((128, 128, 128), dtype=jnp.float32)

    jb = jax_bench(jit_rounds=10, jit_warmup=1)
    jb.measure(fn_jit=jax.jit(rfft), x=x)

The plot will be saved automatically in a reasonable path if a figure is returned, but of course you also have the option to save the figure yourself and not return anything.

For parameterized tests it is additionally possible to define a summary plot via plot_summary. In this case the data will contain the merged information of the tests with different parameters (plot would create one plot per parameter instead). For convenience it is also possible to set the only_last flag that will pre-filter the data, so that it only contains the result of the most recent run per parameter/tag.

def custom_plot_par(data):
    fig = plt.figure()
    plt.xlabel("n")
    plt.ylabel("jit_mean_ms")
    plt.plot(data["n"], data["jit_mean_ms"])
    return fig

@pytest.mark.ptjb(plot_summary=custom_plot_par, only_last=True)
@pytest.mark.parametrize("n", [128, 170, 220, 270])
def test_pars_with_custom_plot(jax_bench, n):
    x = jnp.ones((n, n, n), dtype=jnp.float32)

    jb = jax_bench(jit_rounds=10, jit_warmup=1)
    jb.measure(fn_jit=jax.jit(rfft), x=x)

Plotting SVG Graphs

You can autocreate SVG graphs of your jitted function by passing the option "--ptjb-save-graph". A new run's graph will only be saved if it differs from the last saved graph. (The difference detection is a bit challenging and may be slightly noisy.) Be aware that some graphs may get quite large! You can also toggle this per test by setting @pytest.mark.ptjb(save_graph=True)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pytest_jax_bench-0.5.0.tar.gz (28.2 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

pytest_jax_bench-0.5.0-py3-none-any.whl (24.1 kB view details)

Uploaded Python 3

File details

Details for the file pytest_jax_bench-0.5.0.tar.gz.

File metadata

  • Download URL: pytest_jax_bench-0.5.0.tar.gz
  • Upload date:
  • Size: 28.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for pytest_jax_bench-0.5.0.tar.gz
Algorithm Hash digest
SHA256 0862aec8142a4bb35c1a66647b771b2a9bd531262319ea0d2aa07a8ab9545eca
MD5 f095f9835dca2c84264951be526cf0e8
BLAKE2b-256 ac543656f955da9c2c60c92f3902fd394ba85be58872b3bc6eaea073fc95e239

See more details on using hashes here.

File details

Details for the file pytest_jax_bench-0.5.0-py3-none-any.whl.

File metadata

File hashes

Hashes for pytest_jax_bench-0.5.0-py3-none-any.whl
Algorithm Hash digest
SHA256 ed29ad41348413abacc81317040492613dc13c3e9d9e0da8c460a1e0f2fcd6b4
MD5 acd04f8a2b0966b94b7a09d3016d84a9
BLAKE2b-256 a0a148312c90246929292e110b7b61afc7d4bbbaf4a6e5cfefdfc8ff2dfc23df

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page