
JAX Monitoring to send JAX metrics to Google Cloud Monitoring


JAX Cloud Monitoring

This library integrates JAX with Google Cloud Monitoring (formerly Stackdriver). It captures JAX monitoring events, such as compilation durations, and exports them as custom metrics to Google Cloud Monitoring.

Key Features

  • Non-blocking Export (Multiprocessing): metric export runs in a separate background process, so network I/O to the Cloud Monitoring API neither blocks your JAX workload nor contends for the Global Interpreter Lock (GIL).
  • Easy Integration: hooks into jax.monitoring with a simple registration call.
  • Configurable: customize metric prefixes, monitored resources, and labels.

Installation

pip install jax-monitoring

Note: You may need to install from source or a private registry if this package is not published to PyPI.

Usage

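import jax
import jax.numpy as jnp
import jax_monitoring as jm

def main():
    # Initialize before any JAX work so the first compilation is captured.
    jm.init(
        job_name="my-training-job",
    )

    # Run a jitted function to trigger compilation and emit compilation metrics.
    # Then look for the backend_compile_duration metric in Cloud Monitoring > Metrics Explorer.
    x = jax.jit(lambda x: x + x)(jnp.ones((1000, 1000)))
    x.block_until_ready()

# The __main__ guard matters because the exporter runs in a child process
# (required when multiprocessing uses the "spawn" start method).
if __name__ == "__main__":
    main()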
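Before checking Metrics Explorer, it can help to see which duration events a jitted call actually emits. The sketch below is a standalone debugging aid, not part of jax_monitoring: DurationRecorder is a hypothetical name, and JAX event names are path-like strings whose exact values can differ between JAX versions.

```python
from collections import defaultdict
from typing import DefaultDict, List


class DurationRecorder:
    """Hypothetical in-memory sink for jax.monitoring duration events."""

    def __init__(self) -> None:
        # Maps event name -> every duration (in seconds) seen for it.
        self.durations: DefaultDict[str, List[float]] = defaultdict(list)

    def __call__(self, event: str, duration_secs: float, **kwargs) -> None:
        # Newer JAX versions may pass extra metadata as keyword arguments;
        # this sketch ignores it.
        self.durations[event].append(duration_secs)

    def matching(self, substring: str) -> List[str]:
        """Event names containing substring, e.g. "backend_compile_duration"."""
        return sorted(e for e in self.durations if substring in e)
```

With JAX installed, register an instance with jax.monitoring.register_event_duration_secs_listener(recorder) before the jitted call, then inspect recorder.matching("backend_compile_duration") afterwards.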

Configuration

You can configure the behavior using jm.init():

  • project_id (str): GCP Project ID. If not provided, the library will attempt to infer it from the environment.
  • metric_prefix (str): Prefix for all exported metrics. Default: custom.googleapis.com/jax/monitoring.
  • job_name (str): A label added to all metrics to identify the job. Default: jax_job.
  • monitored_resource_type (str): The Stackdriver monitored resource type. Default: global, or gce_instance when running on Compute Engine (GCE).
  • monitored_resource_labels (dict): Labels for the monitored resource.
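To make the defaults above concrete, here is a hypothetical dataclass mirroring the jm.init() options. It is a sketch under assumptions, not the library's code: MonitoringConfig, its helpers, and the rule that a metric type is the prefix joined to the event name are all illustrative. The {"type", "labels"} dictionary follows the shape of a Cloud Monitoring MonitoredResource. Project ID inference is left out.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class MonitoringConfig:
    """Hypothetical mirror of the jm.init() options and their documented defaults."""

    project_id: Optional[str] = None  # None: infer from the environment (not shown)
    metric_prefix: str = "custom.googleapis.com/jax/monitoring"
    job_name: str = "jax_job"
    monitored_resource_type: Optional[str] = None  # None: "global" or "gce_instance"
    monitored_resource_labels: Dict[str, str] = field(default_factory=dict)

    def metric_type(self, name: str) -> str:
        # Assumption: full metric type = prefix + "/" + metric name.
        return f"{self.metric_prefix.rstrip('/')}/{name.lstrip('/')}"

    def resource(self, on_gce: bool = False) -> Dict[str, object]:
        # Fall back to the documented defaults when no type was configured.
        rtype = self.monitored_resource_type or ("gce_instance" if on_gce else "global")
        return {"type": rtype, "labels": dict(self.monitored_resource_labels)}
```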

How it Works

When jm.init() is called, the library:

  1. Starts a background multiprocessing.Process.
  2. Registers a callback with jax.monitoring.
  3. When JAX triggers an event (e.g., a compilation finishes), the callback puts the event data into a multiprocessing.Queue.
  4. The background worker takes events off the queue, groups them into batches, and sends each batch to the Google Cloud Monitoring API in a single create_time_series call.

Because HTTP requests to the monitoring backend are made only in the worker process, the main training loop never waits on them; the per-event cost in the training process is the callback and a queue put.
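Enqueuing is the one piece of export work left in the training process. With an unbounded multiprocessing.Queue, put returns without waiting, but a stalled or crashed worker lets the backlog grow without limit. One way to bound both latency and memory, assumed here rather than taken from the library, is a bounded queue plus put_nowait that drops events when the queue is full and counts the drops. NonBlockingListener and make_bounded_queue are hypothetical names.

```python
import multiprocessing as mp
import queue


class NonBlockingListener:
    """Hypothetical duration-listener callable that never waits on a full queue."""

    def __init__(self, q) -> None:
        self.q = q
        self.dropped = 0  # events discarded because the queue was full

    def __call__(self, event: str, duration_secs: float, **kwargs) -> None:
        try:
            # put_nowait raises queue.Full instead of waiting for free space.
            self.q.put_nowait((event, float(duration_secs)))
        except queue.Full:
            self.dropped += 1


def make_bounded_queue(maxsize: int = 10_000):
    """A multiprocessing queue holding at most maxsize pending events."""
    return mp.Queue(maxsize=maxsize)
```

The same queue object would be handed to the worker process; logging listener.dropped now and then shows whether the exporter keeps up with the event rate.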

Disclaimer

This is a proof of concept and is not production ready. Exporting custom metrics can generate a large Cloud Monitoring bill, and you are fully responsible for the cloud resources you use.



Download files


Source Distribution

jax_monitoring-0.1.0.tar.gz (5.8 kB)

Built Distribution


jax_monitoring-0.1.0-py3-none-any.whl (7.9 kB, Python 3)

File details

Details for the file jax_monitoring-0.1.0.tar.gz.

File metadata

  • Download URL: jax_monitoring-0.1.0.tar.gz
  • Size: 5.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.0

File hashes

Hashes for jax_monitoring-0.1.0.tar.gz:

  • SHA256: d9ff2320b647546117e6f0d00019da15c4aa65b509815ecc51363266fbd04bde
  • MD5: 8bc750cd3f8414dc57f7c135d6ee0e84
  • BLAKE2b-256: 87bb229b2a9a2716798828e0411b40394a6dcebf12779975388eec96e0fda11b


File details

Details for the file jax_monitoring-0.1.0-py3-none-any.whl.


File hashes

Hashes for jax_monitoring-0.1.0-py3-none-any.whl:

  • SHA256: b40957f1f4959a3052ca2240cbae079649c36045fb1a47daff96d3c373d5ed1f
  • MD5: e90aa0a49cd66c343c04193f04d76421
  • BLAKE2b-256: 71bc560f1807e6713b4e1237b36d1b9ed8442148c9965a5840fb85646765f94f

