JAX Monitoring to send JAX metrics to Google Cloud Monitoring
JAX Cloud Monitoring
This library provides Google Cloud Monitoring (Stackdriver) integration for JAX. It captures JAX monitoring events (like compilation time) and exports them as custom metrics to Google Cloud Monitoring.
Key Features
- Non-blocking Export (Multiprocessing): metric export happens in a separate background process (multiprocessing), ensuring that network I/O to the Cloud Monitoring API does not block your JAX workload or contend for the Global Interpreter Lock (GIL).
- Easy Integration: hooks into jax.monitoring with a simple registration call.
- Configurable: customize metric prefixes, monitored resources, and labels.
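To illustrate what hooking into jax.monitoring looks like, here is a minimal sketch that uses JAX's public listener API directly. It is not this library's actual code: `on_duration` and `seen` are hypothetical names, and the `/demo/step` event is recorded by hand so the example does not depend on a real compilation.

```python
import jax.monitoring

# Hypothetical in-memory sink, used here only to make the callback observable.
seen = []


def on_duration(event, duration_secs, **kwargs):
    # **kwargs absorbs any extra metadata that newer JAX versions may pass.
    seen.append((event, duration_secs))


# Register the listener; JAX will call it for every duration event it records.
jax.monitoring.register_event_duration_secs_listener(on_duration)

# Record one event manually to see the listener fire.
jax.monitoring.record_event_duration_secs("/demo/step", 0.25)
print(seen)
```

A listener like this runs inside the JAX process, so it should return quickly. That is presumably why the library only enqueues the event in its callback and leaves the network call to a separate process.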
Installation
pip install jax-monitoring
Note: You may need to install from source or a private registry if this package is not published to PyPI.
Usage
import jax
import jax_monitoring as jm


def main():
    jm.init(
        job_name="my-training-job",
    )
    # Run a jitted function so that JAX emits compilation metrics.
    # Then look for the backend_compile_duration metric in Cloud Monitoring > Metrics Explorer.
    x = jax.jit(lambda x: x + x)(jax.numpy.ones((1000, 1000)))
    x.block_until_ready()


if __name__ == "__main__":
    main()
Configuration
You can configure the behavior using jm.init():
- project_id (str): GCP Project ID. If not provided, the library will attempt to infer it from the environment.
- metric_prefix (str): Prefix for all exported metrics. Default: custom.googleapis.com/jax/monitoring.
- job_name (str): A label added to all metrics to identify the job. Default: jax_job.
- monitored_resource_type (str): The Stackdriver monitored resource type. Default: global, or gce_instance when running on GCE.
- monitored_resource_labels (dict): Labels for the monitored resource.
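As an illustration, a call that sets every option might look like the sketch below. Only the parameter names come from the list above. The project ID, instance ID, and zone are hypothetical placeholders. The label keys are the ones Cloud Monitoring defines for the gce_instance resource type, but whether the library expects exactly this dictionary shape is an assumption.

```python
import jax_monitoring as jm

jm.init(
    project_id="my-gcp-project",  # hypothetical project ID
    metric_prefix="custom.googleapis.com/jax/monitoring",  # the documented default
    job_name="my-training-job",
    monitored_resource_type="gce_instance",
    # Hypothetical values; the keys follow Cloud Monitoring's gce_instance resource.
    monitored_resource_labels={
        "project_id": "my-gcp-project",
        "instance_id": "1234567890123456789",
        "zone": "us-central1-a",
    },
)
```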
How it Works
When jm.init() is called, the library:
- Starts a background multiprocessing.Process.
- Registers a callback with jax.monitoring.
- When JAX triggers an event (e.g., a compilation finishes), the callback puts the event data into a multiprocessing.Queue.
- The background worker picks up events from the queue and sends them in batches to the Google Cloud Monitoring API using the create_time_series call.
This architecture ensures that the main training loop is never blocked by HTTP requests to the monitoring backend.
Disclaimer
This is a proof of concept and is not production ready. You may get a huge cloud bill and you are fully responsible for the usage of the cloud resources.
File details
Details for the file jax_monitoring-0.1.1.tar.gz.
File metadata
- Download URL: jax_monitoring-0.1.1.tar.gz
- Upload date:
- Size: 5.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 397a43e64387064debf3dcb34ecaeb9abe439ef8e65d765f988552c3ab8d8a6e |
| MD5 | 19a4b6d02cf58d72c33e0bb2a41841a8 |
| BLAKE2b-256 | f9ade087a0fccbf1b8542c6c6ee5221e97e039a98980499154b44041a9c16b38 |
File details
Details for the file jax_monitoring-0.1.1-py3-none-any.whl.
File metadata
- Download URL: jax_monitoring-0.1.1-py3-none-any.whl
- Upload date:
- Size: 7.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 411fa5903ee5b11d85169cc3b0c0bdb041797c83c6a52c3446a670ecd9bc6c06 |
| MD5 | ce4dd9665b117fcfe13ca6a0a16c2a9e |
| BLAKE2b-256 | bdc35ba8d7dc7d6f12812bb2d845aa0cef81673a7b77e42ce7ba214fcdcac60d |