Static array shape checking for JAX powered by eval_shape
jaxtyc
Static array shape checking for JAX powered by `jax.eval_shape`.
Reads jaxtyping annotations and verifies shapes at analysis time -- no runtime cost, no FLOPs. Each named dimension is assigned a unique prime number, so two differently named dimensions can never share a size and a shape mismatch cannot be hidden by a coincidental size match.
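The prime scheme can be illustrated without JAX. The sketch below is a hypothetical symbol table (`DimTable`, `primes_from` and their methods are illustrative names, not part of jaxtyc's API): each new dimension name claims the next prime >= 101, so distinct names always get distinct sizes. Because prime factorizations are unique, a product of distinct dimensions (for example a flattened `seq * d_in` axis) can also be traced back to its factors; whether jaxtyc uses that property is not stated here.

```python
from itertools import count
from typing import Iterator


def primes_from(start: int) -> Iterator[int]:
    """Yield primes >= start using trial division (adequate for a few hundred names)."""
    for n in count(max(start, 2)):
        if all(n % d for d in range(2, int(n**0.5) + 1)):
            yield n


class DimTable:
    """Hypothetical symbol table mapping dimension names to unique primes."""

    def __init__(self, start: int = 101) -> None:
        self._primes = primes_from(start)
        self.size_of: dict[str, int] = {}
        self.name_of: dict[int, str] = {}

    def size(self, name: str) -> int:
        # First use of a name claims the next unused prime; later uses reuse it.
        if name not in self.size_of:
            p = next(self._primes)
            self.size_of[name] = p
            self.name_of[p] = name
        return self.size_of[name]

    def shape(self, dims: str) -> tuple[int, ...]:
        # "batch seq d_in" -> one prime per named dimension
        return tuple(self.size(d) for d in dims.split())

    def describe(self, shape: tuple[int, ...]) -> tuple[str, ...]:
        # Map concrete sizes back to names; unknown sizes are shown as numbers.
        return tuple(self.name_of.get(s, str(s)) for s in shape)

    def factor(self, size: int) -> list[str]:
        # Names whose prime divides `size`; exact for products of distinct known dims.
        return [n for p, n in self.name_of.items() if size % p == 0]


table = DimTable()
print(table.shape("batch seq d_in"))    # (101, 103, 107)
print(table.shape("d_in d_out"))        # (107, 109)
print(table.describe((101, 103, 109)))  # ('batch', 'seq', 'd_out')
print(table.factor(103 * 107))          # ['seq', 'd_in']
```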
Features
- Zero runtime cost -- `jax.eval_shape` only; no arrays allocated, no computation executed
- Prime-based symbolic shapes -- each dimension name maps to a unique prime (>= 101), so `d_in != d_out` is guaranteed
- 10 diagnostic rules -- shape/rank mismatch, cross-function propagation, parameter consistency, tuple return checking, trace errors
- Inline suppressions -- `# jaxtyc: ignore` and `# jaxtyc: ignore[rule-name]`
- LSP server -- diagnostics, hover, CodeLens, go-to-definition, references, rename, code actions, completion, semantic tokens, inlay hints, signature help, linked editing, folding, call hierarchy
- LSP multiplexer -- `jaxtyc mux` runs ty/pyright + jaxtyc behind a single stdio pipe
- CLI with 4 output formats -- `full`, `concise`, `json`, `github` (inline PR annotations)
- Flax NNX + Equinox support -- traces bound methods on module instances
- Configurable via `pyproject.toml` -- severity threshold, rule ignoring, file exclusion, einops preferences
Installation
```shell
pip install jaxtyc
# or
uv add jaxtyc
```
Extras:
| Extra | Installs | Use case |
|---|---|---|
| `jaxtyc[watch]` | `watchfiles` | `jaxtyc watch` -- re-check on file save |
| `jaxtyc[flax]` | `flax>=0.10` | Flax NNX module tracing |
| `jaxtyc[equinox]` | `equinox>=0.11` | Equinox module tracing |
| `jaxtyc[einops]` | `einops>=0.8` | einops-style fix suggestions + inlay hints with pattern dim names |
| `jaxtyc[all]` | All of the above | Everything |
Quick Start
```python
# model.py
import jax.numpy as jnp
from jaxtyping import Array, Float
def linear(
    x: Float[Array, "batch seq d_in"],
    w: Float[Array, "d_in d_out"],
) -> Float[Array, "batch seq d_out"]:
    return jnp.matmul(jnp.matmul(x, w), w.T)  # Bug: the extra w.T maps d_out back to d_in, giving (batch, seq, d_in)
```
```console
$ jaxtyc check model.py
model.py:8:0: error[shape-mismatch]
Shape mismatch in return of `linear`
Expected: (batch, seq, d_out)
Got: (batch, seq, d_in)
Found 1 error(s) in 1 function(s) checked (0.03s)
```
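To make the mechanism concrete, here is a minimal sketch of this kind of check written directly against `jax.eval_shape`. It is not jaxtyc's implementation: the prime values (101, 103, 107, 109), the `PRIMES` table and the helpers `abstract` and `names` are assumptions for illustration. The buggy variant applies an extra `w.T` that maps `d_out` back to `d_in`, which yields the `(batch, seq, d_in)` result reported in the diagnostic above; the fixed variant is a plain `x @ w`.

```python
import jax
import jax.numpy as jnp

# Assumed prime assignment for illustration; jaxtyc's real mapping may differ.
PRIMES = {"batch": 101, "seq": 103, "d_in": 107, "d_out": 109}
NAME_OF = {p: name for name, p in PRIMES.items()}


def abstract(dims: str) -> jax.ShapeDtypeStruct:
    """Build an abstract float32 array (no data) from a jaxtyping-style dim string."""
    return jax.ShapeDtypeStruct(tuple(PRIMES[d] for d in dims.split()), jnp.float32)


def names(shape: tuple[int, ...]) -> tuple[str, ...]:
    """Map concrete prime sizes back to dimension names."""
    return tuple(NAME_OF.get(s, str(s)) for s in shape)


def linear_buggy(x, w):
    # (batch, seq, d_in) @ (d_in, d_out) @ (d_out, d_in) -> (batch, seq, d_in)
    return jnp.matmul(jnp.matmul(x, w), w.T)


def linear_fixed(x, w):
    # (batch, seq, d_in) @ (d_in, d_out) -> (batch, seq, d_out)
    return jnp.matmul(x, w)


x, w = abstract("batch seq d_in"), abstract("d_in d_out")
expected = names(abstract("batch seq d_out").shape)

for fn in (linear_buggy, linear_fixed):
    # eval_shape traces fn abstractly: no arrays are allocated and no FLOPs run.
    got = names(jax.eval_shape(fn, x, w).shape)
    status = "ok" if got == expected else "shape-mismatch"
    print(f"{fn.__name__}: expected {expected}, got {got} -> {status}")
```

Because `d_in` and `d_out` are different primes, no accidental size equality can make the buggy variant look correct.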
Editor Integration
VS Code
Build and install the jaxtyc extension from the repository's `editors/vscode` directory:
```shell
cd editors/vscode && npm install && npm run bundle
npx @vscode/vsce package --allow-missing-repository
code --install-extension jaxtyc-*.vsix
```
Or use the justfile: `just vscode-update`
The extension auto-discovers your Python environment (`.venv`, `VIRTUAL_ENV`, or `jaxtyc` on `PATH`) and starts the LSP server automatically. It supports multi-root workspaces with per-folder LSP clients, and includes jaxtyping snippets, a trace visualization webview, and a status bar quick pick menu.
Other Editors
jaxtyc works in any editor that supports LSP (Neovim, Helix, etc.). See the editor setup docs for configuration.
CLI
```shell
jaxtyc check <paths>...        # Shape-check files or directories
jaxtyc trace <file.py::func>   # Trace intermediate shapes through a function
jaxtyc watch <paths>...        # Watch and re-check on change
jaxtyc lsp                     # Start the LSP server (stdio)
jaxtyc mux                     # Start the LSP multiplexer (ty/pyright + jaxtyc)
jaxtyc version                 # Print version
```
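How `jaxtyc trace` collects intermediate shapes is not described here. A rough approximation with public JAX APIs is to build a jaxpr from abstract inputs with `jax.make_jaxpr` and map each equation's output shape back to dimension names. The prime table, the `trace_shapes` helper and the `mlp_block` function below are hypothetical, chosen only to illustrate the idea.

```python
import jax
import jax.numpy as jnp

# Assumed prime assignment (illustrative; not jaxtyc's actual table).
PRIMES = {"batch": 101, "seq": 103, "d_in": 107, "d_out": 109}
NAME_OF = {p: name for name, p in PRIMES.items()}


def abstract(dims: str) -> jax.ShapeDtypeStruct:
    """Abstract float32 input built from a jaxtyping-style dim string."""
    return jax.ShapeDtypeStruct(tuple(PRIMES[d] for d in dims.split()), jnp.float32)


def trace_shapes(fn, *args) -> list[tuple[str, tuple[str, ...]]]:
    """Hypothetical helper: (primitive, named output shape) per top-level jaxpr equation."""
    closed = jax.make_jaxpr(fn)(*args)  # abstract tracing only; nothing is computed
    steps = []
    for eqn in closed.jaxpr.eqns:
        for var in eqn.outvars:
            named = tuple(NAME_OF.get(s, str(s)) for s in var.aval.shape)
            steps.append((eqn.primitive.name, named))
    return steps


def mlp_block(x, w_in, w_out):
    """Hypothetical two-layer block: d_in -> d_out -> d_in."""
    h = jnp.tanh(jnp.matmul(x, w_in))
    return jnp.matmul(h, w_out)


steps = trace_shapes(
    mlp_block,
    abstract("batch seq d_in"),
    abstract("d_in d_out"),
    abstract("d_out d_in"),
)
for prim, named in steps:
    print(f"{prim:>12}  {named}")
```

The primitive names printed depend on the JAX version (for example, `jnp` functions may appear as a single `jit`/`pjit` equation rather than `dot_general`), but the named shapes should read `(batch, seq, d_out)` after the first layer and `(batch, seq, d_in)` at the end.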
Documentation
Full docs at beegass.github.io/jaxtyc.
Contributing
Contributions are welcome! See CONTRIBUTING.md for guidelines.
License
MIT
Download files
File details
Details for the file jaxtyc-0.7.2.tar.gz.
File metadata
- Download URL: jaxtyc-0.7.2.tar.gz
- Upload date:
- Size: 1.9 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.11.3 (`uv publish` from CI on Ubuntu 24.04)
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0055b7942cf929fbc74026b5bc95deaf0979fe901a3114a31317ab489dae4a67 |
| MD5 | 6c9a576df8dec77b2093259495dec92b |
| BLAKE2b-256 | 641844905291ae7ae9550ba9bbd5b442f7653ac583979fe12e3677b8ad0c921b |
File details
Details for the file jaxtyc-0.7.2-py3-none-any.whl.
File metadata
- Download URL: jaxtyc-0.7.2-py3-none-any.whl
- Upload date:
- Size: 87.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.11.3 (`uv publish` from CI on Ubuntu 24.04)
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 176d59e0073f5c5b09889c87e52a220312920caa938cdef5e5241c5574c97fbb |
| MD5 | 7e92f0da0ad11968a91fab859829ea1d |
| BLAKE2b-256 | 590ef2548c2a48ba92c1036315bc94e2a3599397409338140a3768960e7c528d |