
Static array shape checking for JAX powered by eval_shape

Project description


jaxtyc

Static array shape checking for JAX powered by jax.eval_shape.

Reads jaxtyping annotations and verifies shapes at analysis time -- no runtime cost, no FLOPs. Each named dimension is assigned its own prime-number size, so two different dimensions never share a size by accident and a mismatch cannot be hidden by two dimensions happening to be equal.
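The prime assignment can be illustrated in a few lines of plain Python. This is a sketch of the idea only: `primes_from`, `assign_primes` and `to_names` are hypothetical helpers, and jaxtyc's actual assignment order may differ (the Features list only states that primes are >= 101).

```python
import math
from itertools import count


def primes_from(start: int):
    """Yield primes >= start by trial division (fine for small values)."""
    for n in count(max(start, 2)):
        if all(n % d for d in range(2, math.isqrt(n) + 1)):
            yield n


def assign_primes(dim_names: list[str], start: int = 101) -> dict[str, int]:
    """Hypothetical helper: give each distinct dimension name its own prime size."""
    gen = primes_from(start)
    table: dict[str, int] = {}
    for name in dim_names:
        if name not in table:  # a repeated name reuses its existing size
            table[name] = next(gen)
    return table


def to_names(shape: tuple[int, ...], table: dict[str, int]) -> tuple[str, ...]:
    """Map concrete sizes back to names; unambiguous because every size is unique."""
    inverse = {size: name for name, size in table.items()}
    return tuple(inverse.get(size, str(size)) for size in shape)


sizes = assign_primes(["batch", "seq", "d_in", "d_out"])
print(sizes)  # {'batch': 101, 'seq': 103, 'd_in': 107, 'd_out': 109}
print(to_names((101, 103, 107), sizes))  # ('batch', 'seq', 'd_in')
```

Because no two names share a size, a traced output shape maps back to exactly one tuple of dimension names, which is what makes a report like "Got: (batch, seq, d_in)" possible.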

[Screenshot: VS Code inlay hints showing sharding annotations and shape overlays]

[Screenshot: CLI diagnostics showing shape mismatches in Claude Code]

Watch the demo video

Features

  • Zero runtime cost -- jax.eval_shape only; no arrays allocated, no computation executed
  • Prime-based symbolic shapes -- each dimension name maps to a unique prime (>= 101), so d_in != d_out is guaranteed
  • 10 diagnostic rules -- shape/rank mismatch, cross-function propagation, parameter consistency, tuple return checking, trace errors
  • Inline suppressions -- # jaxtyc: ignore and # jaxtyc: ignore[rule-name]
  • LSP server -- diagnostics, hover, CodeLens, go-to-definition, references, rename, code actions, completion, semantic tokens, inlay hints, signature help, linked editing, folding, call hierarchy
  • LSP multiplexer -- jaxtyc mux runs ty/pyright + jaxtyc behind a single stdio pipe
  • CLI with 4 output formats -- full, concise, json, github (inline PR annotations)
  • Flax NNX + Equinox support -- traces bound methods on module instances
  • Configurable via pyproject.toml -- severity threshold, rule ignoring, file exclusion, einops preferences

Installation

pip install jaxtyc
# or
uv add jaxtyc

Extras:

| Extra | Installs | Use case |
|---|---|---|
| jaxtyc[watch] | watchfiles | jaxtyc watch -- re-check on file save |
| jaxtyc[flax] | flax >=0.10 | Flax NNX module tracing |
| jaxtyc[equinox] | equinox >=0.11 | Equinox module tracing |
| jaxtyc[einops] | einops >=0.8 | einops-style fix suggestions + inlay hints with pattern dim names |
| jaxtyc[all] | All of the above | Everything |

Quick Start

# model.py
import jax.numpy as jnp
from jaxtyping import Array, Float

def linear(
    x: Float[Array, "batch seq d_in"],
    w: Float[Array, "d_in d_out"],
) -> Float[Array, "batch seq d_out"]:
    return jnp.einsum("bsi,io->bsi", x, w)  # Bug: output keeps d_in ("i") instead of d_out ("o"), giving (batch, seq, d_in)

$ jaxtyc check model.py
model.py:8:0: error[shape-mismatch]
  Shape mismatch in return of `linear`
    Expected: (batch, seq, d_out)
    Got:      (batch, seq, d_in)

Found 1 error(s) in 1 function(s) checked (0.03s)
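A check like this can be approximated by hand. The sketch below illustrates the general mechanism and is not jaxtyc's code: the prime sizes in `SIZES` and the helpers `spec`, `named`, `linear_ok` and `linear_bad` are hypothetical. It builds abstract inputs with `jax.ShapeDtypeStruct`, runs `jax.eval_shape` (which traces the function without allocating arrays or doing FLOPs), and maps the resulting sizes back to dimension names.

```python
import jax
import jax.numpy as jnp

# Hypothetical prime sizes, one per dimension name (not read from jaxtyc).
SIZES = {"batch": 101, "seq": 103, "d_in": 107, "d_out": 109}
NAMES = {size: name for name, size in SIZES.items()}


def spec(dims: str) -> jax.ShapeDtypeStruct:
    """Turn a jaxtyping-style dimension string into an abstract float32 array."""
    return jax.ShapeDtypeStruct(tuple(SIZES[d] for d in dims.split()), jnp.float32)


def named(shape: tuple[int, ...]) -> tuple[str, ...]:
    """Map prime sizes back to dimension names."""
    return tuple(NAMES.get(s, str(s)) for s in shape)


def linear_ok(x, w):
    return jnp.matmul(x, w)  # (batch, seq, d_in) @ (d_in, d_out)


def linear_bad(x, w):
    # Hypothetical bug: the output subscript keeps d_in ("i") instead of d_out ("o").
    return jnp.einsum("bsi,io->bsi", x, w)


x, w = spec("batch seq d_in"), spec("d_in d_out")
expected = named(spec("batch seq d_out").shape)

for fn in (linear_ok, linear_bad):
    got = named(jax.eval_shape(fn, x, w).shape)
    print(fn.__name__, "OK" if got == expected else f"expected {expected}, got {got}")

# Contracting mismatched dims (d_in vs d_out) fails during tracing instead of returning a shape.
try:
    jax.eval_shape(lambda a, b: jnp.matmul(a, b.T), x, w)
    trace_failed = False
except Exception as err:  # JAX rejects the incompatible contracting dimensions
    trace_failed = True
    print("trace error:", type(err).__name__)
```

The last case shows why distinct sizes matter: if `d_in` and `d_out` were both, say, 4, the transposed matmul would trace successfully and the bug would go unnoticed.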

Editor Integration

VS Code

Build and install the jaxtyc extension from a clone of the repository:

cd editors/vscode && npm install && npm run bundle
npx @vscode/vsce package --allow-missing-repository
code --install-extension jaxtyc-*.vsix

Or use the justfile: just vscode-update

The extension auto-discovers your Python environment (.venv, VIRTUAL_ENV, or jaxtyc on PATH) and starts the LSP server automatically. Supports multi-root workspaces with per-folder LSP clients. Includes jaxtyping snippets, a trace visualization webview, and a status bar quick pick menu.

Other Editors

jaxtyc works in any editor that supports LSP (Neovim, Helix, etc.). See the editor setup docs for configuration.

CLI

jaxtyc check <paths>...          # Shape-check files or directories
jaxtyc trace <file.py::func>     # Trace intermediate shapes through a function
jaxtyc watch <paths>...          # Watch and re-check on change
jaxtyc lsp                       # Start the LSP server (stdio)
jaxtyc mux                       # Start the LSP multiplexer (ty/pyright + jaxtyc)
jaxtyc version                   # Print version

Documentation

Full docs at beegass.github.io/jaxtyc.

Contributing

Contributions are welcome! See CONTRIBUTING.md for guidelines.

License

MIT



Download files

Download the file for your platform.

Source Distribution

jaxtyc-0.7.2.tar.gz (1.9 MB)

Uploaded Source

Built Distribution


jaxtyc-0.7.2-py3-none-any.whl (87.4 kB)

Uploaded Python 3

File details

Details for the file jaxtyc-0.7.2.tar.gz.

File metadata

  • Download URL: jaxtyc-0.7.2.tar.gz
  • Upload date:
  • Size: 1.9 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.11.3 (uv publish, Ubuntu 24.04, CI)

File hashes

Hashes for jaxtyc-0.7.2.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0055b7942cf929fbc74026b5bc95deaf0979fe901a3114a31317ab489dae4a67 |
| MD5 | 6c9a576df8dec77b2093259495dec92b |
| BLAKE2b-256 | 641844905291ae7ae9550ba9bbd5b442f7653ac583979fe12e3677b8ad0c921b |


File details

Details for the file jaxtyc-0.7.2-py3-none-any.whl.

File metadata

  • Download URL: jaxtyc-0.7.2-py3-none-any.whl
  • Upload date:
  • Size: 87.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.11.3 (uv publish, Ubuntu 24.04, CI)

File hashes

Hashes for jaxtyc-0.7.2-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 176d59e0073f5c5b09889c87e52a220312920caa938cdef5e5241c5574c97fbb |
| MD5 | 7e92f0da0ad11968a91fab859829ea1d |
| BLAKE2b-256 | 590ef2548c2a48ba92c1036315bc94e2a3599397409338140a3768960e7c528d |
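The SHA256 digests listed for the two files can be checked after downloading. A minimal sketch using only the standard library; `sha256_of` and `verify` are hypothetical helper names, and the expected values are copied from the hash listings for jaxtyc 0.7.2 on this page.

```python
import hashlib
from pathlib import Path

# SHA256 digests published for jaxtyc 0.7.2.
EXPECTED = {
    "jaxtyc-0.7.2.tar.gz": "0055b7942cf929fbc74026b5bc95deaf0979fe901a3114a31317ab489dae4a67",
    "jaxtyc-0.7.2-py3-none-any.whl": "176d59e0073f5c5b09889c87e52a220312920caa938cdef5e5241c5574c97fbb",
}


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash a file in fixed-size chunks so large files need not fit in memory."""
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: Path) -> bool:
    """True only if the file name is known and its digest matches the published one."""
    expected = EXPECTED.get(path.name)
    return expected is not None and sha256_of(path) == expected
```

For example, `verify(Path("jaxtyc-0.7.2-py3-none-any.whl"))` after downloading the wheel. pip can also enforce digests itself in hash-checking mode: list `jaxtyc==0.7.2 --hash=sha256:<digest>` in a requirements file and run `pip install --require-hashes -r requirements.txt` (every dependency then needs a pinned hash as well).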

