
MCP server for JAX documentation


jax-mcp

MCP (Model Context Protocol) server for JAX documentation.

Enables LLMs to access up-to-date JAX documentation and validate generated code.

Installation

pip install -e .

Or run directly:

python -m jax_mcp

Usage with Claude Code

# Add as MCP server
claude mcp add -t stdio -s user jax -- python -m jax_mcp

# Or if installed globally
claude mcp add -t stdio -s user jax -- jax-mcp

Configuration

| Environment variable | Default | Description |
|---|---|---|
| JAX_DOCS_PATH | (none) | Path to a local JAX docs directory (enables offline mode) |
| JAX_MCP_CACHE_DIR | ~/.cache/jax-mcp | Cache directory for online mode |
| JAX_MCP_CACHE_TTL | 24 | Cache time-to-live, in hours |
| JAX_MCP_NO_CACHE | 0 | Set to 1 to disable caching |
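The variables in the configuration table map naturally onto a small settings object. The following is a minimal sketch, assuming the server reads them roughly this way. `Config` and `load_config` are hypothetical names, not part of the jax-mcp API, and parsing the TTL as a float is an assumption.

```python
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Mapping, Optional


@dataclass(frozen=True)
class Config:
    # Hypothetical settings object; field names are illustrative only.
    docs_path: Optional[Path]  # JAX_DOCS_PATH; None means online mode
    cache_dir: Path            # JAX_MCP_CACHE_DIR
    cache_ttl_hours: float     # JAX_MCP_CACHE_TTL
    no_cache: bool             # True when JAX_MCP_NO_CACHE is "1"


def load_config(env: Mapping[str, str] = os.environ) -> Config:
    """Read the documented variables, falling back to the documented defaults."""
    docs = env.get("JAX_DOCS_PATH")
    return Config(
        docs_path=Path(docs).expanduser() if docs else None,
        cache_dir=Path(env.get("JAX_MCP_CACHE_DIR", "~/.cache/jax-mcp")).expanduser(),
        cache_ttl_hours=float(env.get("JAX_MCP_CACHE_TTL", "24")),
        no_cache=env.get("JAX_MCP_NO_CACHE", "0") == "1",
    )
```

Passing `env` explicitly (instead of always reading `os.environ`) keeps the function easy to test with a plain dict.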

Offline Mode

Point JAX_DOCS_PATH at the docs directory of a local JAX clone for offline access:

export JAX_DOCS_PATH=/path/to/jax/docs
python -m jax_mcp
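In offline mode the server reads files directly from that directory. A minimal sketch of such a lookup follows, assuming documents are addressed by a path relative to JAX_DOCS_PATH (the README does not say how documents are named). `read_local_doc` is a hypothetical helper; the containment check shows one reasonable way to stop a requested name from escaping the docs directory.

```python
from pathlib import Path


def read_local_doc(docs_root: str, relative_name: str) -> str:
    """Read one file from a local JAX docs checkout (hypothetical helper)."""
    root = Path(docs_root).expanduser().resolve()
    target = (root / relative_name).resolve()
    # Refuse names such as "../secret.txt" that resolve outside the docs directory.
    if root != target and root not in target.parents:
        raise ValueError(f"{relative_name!r} is outside {root}")
    return target.read_text(encoding="utf-8")
```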

Online Mode (Default)

When JAX_DOCS_PATH is not set, the server fetches docs from GitHub and caches them locally:

python -m jax_mcp
# Fetches from: raw.githubusercontent.com/google/jax/main/docs/
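The caching behaviour described by JAX_MCP_CACHE_DIR and JAX_MCP_CACHE_TTL can be sketched as a freshness check on a cached file's modification time. This is an assumption about how the cache works, not the server's actual code: `doc_url` and `cached_copy` are hypothetical names, and only the base URL comes from this README.

```python
import time
from pathlib import Path
from typing import Optional

# Base URL from the README; the layout of files under docs/ is assumed.
BASE_URL = "https://raw.githubusercontent.com/google/jax/main/docs/"


def doc_url(relative_name: str) -> str:
    """Build the raw GitHub URL for a docs file (hypothetical helper)."""
    return BASE_URL + relative_name.lstrip("/")


def cached_copy(cache_dir: Path, relative_name: str, ttl_hours: float,
                now: Optional[float] = None) -> Optional[str]:
    """Return cached text if it is younger than ttl_hours, else None."""
    path = cache_dir / relative_name
    if not path.is_file():
        return None  # never fetched: caller should download doc_url(relative_name)
    current = time.time() if now is None else now
    if current - path.stat().st_mtime > ttl_hours * 3600:
        return None  # stale: caller should re-fetch and overwrite the file
    return path.read_text(encoding="utf-8")
```

The optional `now` argument exists only so the TTL logic can be checked without waiting for real time to pass.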

Available Tools

| Tool | Description |
|---|---|
| list-sections | List all documentation sections by category |
| get-documentation | Fetch specific documentation content |
| jax-checker | Validate JAX code for common gotchas |
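MCP clients invoke these tools with a JSON-RPC `tools/call` request, sent over stdio as one newline-terminated JSON message. The sketch below only builds that message; a real client first performs the MCP `initialize` handshake, which is omitted here. The argument name `code` is a hypothetical placeholder, since this README does not document the tools' input schemas (a client can discover them with `tools/list`).

```python
import json


def tools_call(request_id: int, tool: str, arguments: dict) -> str:
    """Serialize an MCP tools/call request as a single line for the stdio transport."""
    message = {
        "jsonrpc": "2.0",
        "id": request_id,
        "method": "tools/call",
        "params": {"name": tool, "arguments": arguments},
    }
    # json.dumps without indent emits no embedded newlines, as stdio framing requires.
    return json.dumps(message) + "\n"


# "code" is a hypothetical argument name for jax-checker.
line = tools_call(1, "jax-checker", {"code": "x[0] = 1.0"})
```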

Documentation Categories

  • concepts: Core JAX concepts (pytrees, transformations, tracing)
  • gotchas: Common mistakes and how to avoid them
  • transforms: jit, vmap, grad, scan patterns
  • advanced: Distributed computing, custom pytrees
  • performance: GPU tips, profiling, benchmarking
  • api: Module overviews (jax.numpy, jax.lax, jax.random)
  • examples: Practical code examples

JAX Checker

The jax-checker tool catches common JAX mistakes:

  • In-place array mutations (array[idx] = value)
  • Side effects in jitted functions (print, globals)
  • PRNG key reuse without splitting
  • Python control flow in traced code
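  • Missing block_until_ready() in benchmarks (JAX dispatches work asynchronously, so timings taken without it can be misleading)
  • Float64 usage without enabling the jax_enable_x64 config option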

Development

# Install dev dependencies
pip install -e ".[dev]"

# Run tests
pytest

License

MIT

Download files

Source Distribution

jax_mcp-0.1.0.tar.gz (66.9 kB)

Built Distribution

jax_mcp-0.1.0-py3-none-any.whl (13.3 kB)

File details

Details for the file jax_mcp-0.1.0.tar.gz.

File metadata

  • Download URL: jax_mcp-0.1.0.tar.gz
  • Size: 66.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.5

File hashes

Hashes for jax_mcp-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 21667b504ba291e1910efd3f1bffc4c33b65aa9df31cb9460b2b6cb7d98fefd2 |
| MD5 | 5c2e9e60ff7f64ef04dda21920dd1758 |
| BLAKE2b-256 | 4a4da3556390db12c21e992e93913ef40c3d6c7769c2cd6243b2b7eccafa76a4 |

File details

Details for the file jax_mcp-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: jax_mcp-0.1.0-py3-none-any.whl
  • Size: 13.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.5

File hashes

Hashes for jax_mcp-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7722d21a2cbfc51319bec47cbdb9cdb2f5dea5d08312054f16daa001517f5073 |
| MD5 | 3a556ab6d2a8df75ef4dbee96eb784f9 |
| BLAKE2b-256 | 27964387e701a8399e12d50c709ac7b552a8d42c0de8623118eb5a94fa161f71 |
