MCP server for JAX documentation
jax-mcp
MCP (Model Context Protocol) server for JAX documentation.
Enables LLMs to access up-to-date JAX documentation and validate generated code.
Installation
```bash
pip install -e .
```

Or run directly:

```bash
python -m jax_mcp
```
Usage with Claude Code
```bash
# Add as an MCP server (stdio transport, user scope)
claude mcp add -t stdio -s user jax -- python -m jax_mcp
# Or, if the jax-mcp console script is on your PATH
claude mcp add -t stdio -s user jax -- jax-mcp
```
Configuration
| Environment Variable | Default | Description |
|---|---|---|
| `JAX_DOCS_PATH` | (none) | Path to a local JAX docs directory (offline mode) |
| `JAX_MCP_CACHE_DIR` | `~/.cache/jax-mcp` | Cache directory for online mode |
| `JAX_MCP_CACHE_TTL` | `24` | Cache TTL in hours |
| `JAX_MCP_NO_CACHE` | `0` | Set to `1` to disable caching |
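As a minimal sketch of how these variables might be resolved at startup, each one read with the default from the table above. The `JaxMcpConfig` class and `load_config` function are hypothetical illustrations, not part of jax-mcp's documented API:

```python
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Mapping, Optional


@dataclass(frozen=True)
class JaxMcpConfig:
    """Hypothetical container for the documented settings."""
    docs_path: Optional[Path]  # None means online mode
    cache_dir: Path
    cache_ttl_hours: float
    no_cache: bool

    @property
    def offline(self) -> bool:
        return self.docs_path is not None


def load_config(environ: Mapping[str, str] = os.environ) -> JaxMcpConfig:
    """Read the documented environment variables, falling back to their defaults."""
    docs = environ.get("JAX_DOCS_PATH")
    return JaxMcpConfig(
        docs_path=Path(docs).expanduser() if docs else None,
        cache_dir=Path(environ.get("JAX_MCP_CACHE_DIR", "~/.cache/jax-mcp")).expanduser(),
        cache_ttl_hours=float(environ.get("JAX_MCP_CACHE_TTL", "24")),
        no_cache=environ.get("JAX_MCP_NO_CACHE", "0") == "1",
    )


cfg = load_config({})  # empty environment: every default applies
print(cfg.offline, cfg.cache_ttl_hours, cfg.no_cache)  # False 24.0 False
```

Passing the environment as a mapping keeps the function easy to test; in a server you would call `load_config()` with no arguments.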
Offline Mode
Point `JAX_DOCS_PATH` at the `docs/` directory of a local JAX clone for offline access:

```bash
export JAX_DOCS_PATH=/path/to/jax/docs
python -m jax_mcp
```
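A rough sketch of offline lookup, assuming the server reads source files directly from `JAX_DOCS_PATH`. The helper names and the `.md`/`.rst`/`.ipynb` suffix list are assumptions for illustration, not jax-mcp's actual code:

```python
from pathlib import Path
from typing import List, Optional

# Assumed source formats: a JAX docs checkout mixes Markdown, reStructuredText and notebooks.
DOC_SUFFIXES = (".md", ".rst", ".ipynb")


def list_local_docs(docs_root: Path) -> List[str]:
    """Return every doc under docs_root as a POSIX-style path without its suffix, sorted."""
    return sorted(
        p.relative_to(docs_root).with_suffix("").as_posix()
        for p in docs_root.rglob("*")
        if p.is_file() and p.suffix in DOC_SUFFIXES
    )


def read_local_doc(docs_root: Path, name: str) -> Optional[str]:
    """Return the text of the first file named name + a known suffix, or None."""
    root = docs_root.resolve()
    for suffix in DOC_SUFFIXES:
        candidate = (root / f"{name}{suffix}").resolve()
        # Refuse names such as "../secrets" that would escape the docs directory.
        if not candidate.is_relative_to(root):
            return None
        if candidate.is_file():
            return candidate.read_text(encoding="utf-8")
    return None
```

The containment check matters because an MCP server receives document names from the model, so a lookup should never read files outside the configured directory.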
Online Mode (Default)
Fetches docs from GitHub with local caching:

```bash
python -m jax_mcp
# Fetches from: raw.githubusercontent.com/google/jax/main/docs/
```
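The online path can be pictured as a small time-to-live cache. This is a hypothetical sketch (`fetch_doc`, `BASE_URL` and the hashed cache-file naming are illustrative choices) that mirrors the `JAX_MCP_CACHE_TTL` and `JAX_MCP_NO_CACHE` semantics from the configuration table. The HTTP call is passed in as a parameter so the caching logic can be exercised without network access:

```python
import hashlib
import time
import urllib.request
from pathlib import Path
from typing import Callable

# Base URL taken from the README; the relative doc path is supplied by the caller.
BASE_URL = "https://raw.githubusercontent.com/google/jax/main/docs/"


def _http_get(url: str) -> str:
    with urllib.request.urlopen(url, timeout=30) as resp:
        return resp.read().decode("utf-8")


def fetch_doc(
    rel_path: str,
    cache_dir: Path,
    ttl_hours: float = 24,
    use_cache: bool = True,
    fetch: Callable[[str], str] = _http_get,
) -> str:
    """Return the doc at BASE_URL + rel_path, reusing a cached copy younger than ttl_hours."""
    url = BASE_URL + rel_path
    if not use_cache:
        return fetch(url)
    cache_dir.mkdir(parents=True, exist_ok=True)
    # One cache file per URL, named by a hash so any path becomes a safe filename.
    cache_file = cache_dir / (hashlib.sha256(url.encode("utf-8")).hexdigest() + ".txt")
    if cache_file.is_file():
        age_seconds = time.time() - cache_file.stat().st_mtime
        if age_seconds < ttl_hours * 3600:
            return cache_file.read_text(encoding="utf-8")
    text = fetch(url)
    cache_file.write_text(text, encoding="utf-8")
    return text
```

Using the file's modification time as the cache timestamp avoids a separate metadata store: rewriting the file on each fetch automatically resets its age.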
Available Tools
| Tool | Description |
|---|---|
| `list-sections` | List all documentation sections by category |
| `get-documentation` | Fetch specific documentation content |
| `jax-checker` | Validate JAX code for common gotchas |
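MCP clients invoke these tools with JSON-RPC 2.0 `tools/call` requests, and over the stdio transport each message is a single line of JSON. The sketch below only builds such messages; a real client such as Claude Code also performs the `initialize` handshake first and reads each tool's argument schema from `tools/list`. The argument name `section` is a hypothetical placeholder, not taken from jax-mcp's schema:

```python
import json
from itertools import count
from typing import Any, Dict

_ids = count(1)  # JSON-RPC request ids must be unique per session


def tool_call(name: str, arguments: Dict[str, Any]) -> str:
    """Serialize an MCP tools/call request as one newline-terminated line of JSON."""
    message = {
        "jsonrpc": "2.0",
        "id": next(_ids),
        "method": "tools/call",
        "params": {"name": name, "arguments": arguments},
    }
    # json.dumps without indent escapes embedded newlines, so "\n" can frame the message.
    return json.dumps(message) + "\n"


print(tool_call("list-sections", {}), end="")
# "section" is a hypothetical argument name; check the real schema via tools/list.
print(tool_call("get-documentation", {"section": "gotchas"}), end="")
```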
Documentation Categories
- concepts: Core JAX concepts (pytrees, transformations, tracing)
- gotchas: Common mistakes and how to avoid them
- transforms: jit, vmap, grad, scan patterns
- advanced: Distributed computing, custom pytrees
- performance: GPU tips, profiling, benchmarking
- api: Module overviews (jax.numpy, jax.lax, jax.random)
- examples: Practical code examples
JAX Checker
The jax-checker tool catches common JAX mistakes:
- In-place array mutations (`array[idx] = value`)
- Side effects in jitted functions (`print`, globals)
- PRNG key reuse without splitting
- Python control flow in traced code
- Missing `block_until_ready()` in benchmarks
- Float64 usage without enabling x64 in the JAX config
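For reference, the standard JAX fixes for the patterns listed above look roughly like this. It is a sketch of common JAX practice, not output produced by `jax-checker`:

```python
import time

import jax
import jax.numpy as jnp

# Float64: enable x64 before creating any arrays, otherwise JAX keeps float32 precision.
jax.config.update("jax_enable_x64", True)

# In-place mutation: JAX arrays are immutable, so use the functional .at[] API.
x = jnp.zeros(5)
x = x.at[2].set(1.0)  # returns a new array; x[2] = 1.0 would raise a TypeError


# Side effects: a plain print() fires only while tracing; jax.debug.print runs on every call.
@jax.jit
def scaled(v):
    jax.debug.print("sum = {s}", s=v.sum())
    return v * 2.0


# PRNG keys: split off a fresh subkey for each use instead of reusing one key.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (5,))


# Control flow on traced values: use jnp.where (or jax.lax.cond) instead of Python `if`.
@jax.jit
def relu(v):
    return jnp.where(v > 0, v, 0.0)


# Benchmarks: dispatch is asynchronous, so wait for the result before stopping the clock.
y = scaled(x + noise)  # first call includes compilation; exclude it from timing
start = time.perf_counter()
y = scaled(x + noise).block_until_ready()
elapsed = time.perf_counter() - start
print(x.dtype, relu(jnp.array([-1.0, 2.0])), f"{elapsed:.6f}s")
```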
Development
```bash
# Install dev dependencies
pip install -e ".[dev]"
# Run tests
pytest
```
License
MIT
Download files
File details
Details for the file jax_mcp-0.1.0.tar.gz.
File metadata
- Download URL: jax_mcp-0.1.0.tar.gz
- Upload date:
- Size: 66.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.5
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `21667b504ba291e1910efd3f1bffc4c33b65aa9df31cb9460b2b6cb7d98fefd2` |
| MD5 | `5c2e9e60ff7f64ef04dda21920dd1758` |
| BLAKE2b-256 | `4a4da3556390db12c21e992e93913ef40c3d6c7769c2cd6243b2b7eccafa76a4` |
File details
Details for the file jax_mcp-0.1.0-py3-none-any.whl.
File metadata
- Download URL: jax_mcp-0.1.0-py3-none-any.whl
- Upload date:
- Size: 13.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.5
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `7722d21a2cbfc51319bec47cbdb9cdb2f5dea5d08312054f16daa001517f5073` |
| MD5 | `3a556ab6d2a8df75ef4dbee96eb784f9` |
| BLAKE2b-256 | `27964387e701a8399e12d50c709ac7b552a8d42c0de8623118eb5a94fa161f71` |
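To check a downloaded file against the SHA256 digests listed above, a short `hashlib` sketch works (the `sha256_of` and `verify` helper names are illustrative):

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so large downloads need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digests copied from the file details on this page.
EXPECTED = {
    "jax_mcp-0.1.0.tar.gz": "21667b504ba291e1910efd3f1bffc4c33b65aa9df31cb9460b2b6cb7d98fefd2",
    "jax_mcp-0.1.0-py3-none-any.whl": "7722d21a2cbfc51319bec47cbdb9cdb2f5dea5d08312054f16daa001517f5073",
}


def verify(path: Path) -> bool:
    """True if the file's SHA256 matches the published digest for its filename."""
    return sha256_of(path) == EXPECTED[path.name]
```

For example, `verify(Path("jax_mcp-0.1.0-py3-none-any.whl"))` returns `True` for an intact wheel. From the command line, `pip hash <file>` prints the same SHA256 digest in pip's `--hash=sha256:...` format.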