JAX Monitoring to send JAX metrics to Google Cloud Monitoring

JAX Cloud Monitoring

This library provides Google Cloud Monitoring (Stackdriver) integration for JAX. It captures JAX monitoring events (like compilation time) and exports them as custom metrics to Google Cloud Monitoring.

Key Features

  • Non-blocking Export: metric export runs in a separate background process (multiprocessing), so network I/O to the Cloud Monitoring API neither blocks your JAX workload nor contends for the Global Interpreter Lock (GIL).
  • Easy Integration: hooks into jax.monitoring with a simple registration call.
  • Configurable: customize metric prefixes, monitored resources, and labels.

Installation

pip install jax-monitoring

Note: You may need to install from source or a private registry if this package is not published to PyPI.

Usage

import jax
import jax_monitoring as jm

def main():
    jm.init(
        job_name="my-training-job",
    )

    # Run a jitted function so that JAX emits compilation metrics.
    # Then look for the backend_compile_duration metric in Cloud Monitoring > Metrics Explorer.
    x = jax.jit(lambda x: x + x)(jax.numpy.ones((1000, 1000)))
    x.block_until_ready()

if __name__ == "__main__":
    main()

Configuration

You can configure the behavior using jm.init():

  • project_id (str): GCP Project ID. If not provided, the library will attempt to infer it from the environment.
  • metric_prefix (str): Prefix for all exported metrics. Default: custom.googleapis.com/jax/monitoring.
  • job_name (str): A label added to all metrics to identify the job. Default: jax_job.
  • monitored_resource_type (str): The Cloud Monitoring (Stackdriver) monitored resource type. Default: gce_instance when running on GCE, otherwise global.
  • monitored_resource_labels (dict): Labels for the monitored resource.
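
As an illustration, the options above could be passed together as shown in this minimal sketch. The parameter names come from the list above; every value is a placeholder, and the shape of monitored_resource_labels (a project_id label for the global resource type) is an assumption rather than something this README confirms.

```python
# Keyword arguments for jm.init(), using the parameter names listed above.
# All values are placeholders chosen for illustration.
INIT_KWARGS = {
    "project_id": "my-gcp-project",
    "metric_prefix": "custom.googleapis.com/jax/monitoring",
    "job_name": "my-training-job",
    "monitored_resource_type": "global",
    # Assumption: the "global" resource type is labeled by project_id.
    "monitored_resource_labels": {"project_id": "my-gcp-project"},
}


def configure():
    """Initialize jax_monitoring with the settings above."""
    # Imported lazily so this module can be loaded without the package installed.
    import jax_monitoring as jm

    jm.init(**INIT_KWARGS)
```

Call configure() once at startup, before the first jitted function runs, so that compilation events are captured.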

How it Works

When jm.init() is called, the library:

  1. Starts a background multiprocessing.Process.
  2. Registers a callback with jax.monitoring.
  3. When JAX triggers an event (e.g., a compilation finishes), the callback puts the event data into a multiprocessing.Queue.
  4. The background worker picks up events from the queue and sends them in batches to the Google Cloud Monitoring API, so that one create_time_series call carries several data points.

This architecture ensures that the main training loop is never blocked by HTTP requests to the monitoring backend.
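
The steps above can be sketched with the standard library alone. This is not the library's actual code: the names drain_batches, make_listener, print_batch, start, and BATCH_SIZE are hypothetical, the batch size is arbitrary, and print_batch stands in for a real create_time_series request.

```python
import multiprocessing as mp

BATCH_SIZE = 3    # hypothetical batch size, chosen for illustration
_SENTINEL = None  # placed on the queue to tell the worker to stop


def drain_batches(events, send, batch_size=BATCH_SIZE):
    """Read (event, seconds) tuples from `events` and pass them to `send` in batches."""
    batch = []
    while True:
        item = events.get()  # blocks only the worker, never the producer
        if item is _SENTINEL:
            break
        batch.append(item)
        if len(batch) >= batch_size:
            send(batch)
            batch = []
    if batch:  # flush whatever is left when shutting down
        send(batch)


def print_batch(batch):
    # Stand-in for an export request to the monitoring backend.
    print(f"would export {len(batch)} points: {batch}")


def make_listener(events):
    """Build a callback that only enqueues the event and returns immediately."""
    def on_duration(event, duration_secs, **kwargs):
        events.put((event, duration_secs))
    return on_duration


def start():
    """Start the background worker process and return (queue, process)."""
    ctx = mp.get_context("spawn")
    events = ctx.Queue()
    worker = ctx.Process(
        target=drain_batches, args=(events, print_batch), daemon=True
    )
    worker.start()
    return events, worker
```

In a real program, start() would be called under an if __name__ == "__main__": guard, and the callback returned by make_listener(events) would presumably be handed to jax.monitoring.register_event_duration_secs_listener. Because the callback does nothing but put a small tuple on the queue, the cost paid by the training process stays small regardless of how slow the export is.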

Disclaimer

This is a proof of concept and is not production-ready. Exporting custom metrics may incur a large cloud bill, and you are fully responsible for your usage of cloud resources.
