
Static AST-based PyTorch tensor shape analysis.


TorchShapeFlow


TorchShapeFlow is a static, AST-based shape analyzer for PyTorch. It reads your Python source, infers tensor shapes from Annotated[..., Shape(...)] contracts, and reports mismatches as structured diagnostics. No execution required.

from typing import Annotated
import torch
from torchshapeflow import Shape

def attention_scores(
    q: Annotated[torch.Tensor, Shape("B", "H", "T", "D")],
    k: Annotated[torch.Tensor, Shape("B", "H", "T", "D")],
) -> Annotated[torch.Tensor, Shape("B", "H", "T", "T")]:
    return q @ k.transpose(-2, -1)

$ tsf check mymodel.py
All clean (1 file checked)
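For intuition about what this check involves, here is a toy, stdlib-only propagation of symbolic shapes through k.transpose(-2, -1) and @. The helpers transpose and matmul, and the tuple-of-names representation, are hypothetical illustrations, not TorchShapeFlow's API or implementation.

```python
from typing import Tuple, Union

# Hypothetical toy model, NOT TorchShapeFlow's implementation:
# a shape is a tuple of symbolic names ("B") or concrete ints (3).
Dim = Union[str, int]
SymShape = Tuple[Dim, ...]


def transpose(shape: SymShape, dim0: int, dim1: int) -> SymShape:
    """Swap two dimensions; negative indices work as in torch.Tensor.transpose."""
    dims = list(shape)
    dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
    return tuple(dims)


def matmul(a: SymShape, b: SymShape) -> SymShape:
    """Batched matmul for equal-rank operands: (..., n, k) @ (..., k, m) -> (..., n, m).

    Batch-dimension broadcasting is deliberately omitted to keep the sketch short.
    """
    if len(a) != len(b) or len(a) < 2:
        raise ValueError(f"rank mismatch: {a} @ {b}")
    if a[:-2] != b[:-2]:
        raise ValueError(f"batch dims differ: {a[:-2]} vs {b[:-2]}")
    if a[-1] != b[-2]:
        raise ValueError(f"inner dims differ: {a[-1]!r} vs {b[-2]!r}")
    return a[:-1] + (b[-1],)


q = ("B", "H", "T", "D")
k = ("B", "H", "T", "D")
scores = matmul(q, transpose(k, -2, -1))
print(scores)  # ('B', 'H', 'T', 'T')
```

Dropping the transpose, i.e. matmul(q, k), raises ValueError because the inner dimensions 'D' and 'T' differ. Treating distinct symbols as unequal is a simplifying choice of this toy.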

Philosophy

TorchShapeFlow is annotation-first and symbolic-first.

  • You declare tensor shape contracts with Annotated[torch.Tensor, Shape(...)].
  • Symbolic dimensions such as "B", "T", and "D" are the default choice for config-driven model code, where sizes come from configuration rather than being fixed in the source.
  • Integer dimensions are still useful for fixed semantics like RGB channels or known embedding widths.
  • When inference is not possible, the analyzer degrades visibly instead of guessing.
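To make these bullets concrete, here is a toy checker (hypothetical, stdlib only, not TorchShapeFlow's API or internals) showing three behaviors: a symbol binds on first use and must stay consistent afterwards, an integer dimension must match exactly, and a dimension the toy could not infer (None here) produces a visible note rather than a guess.

```python
from typing import Dict, List, Optional, Sequence, Union

# Hypothetical toy checker; NOT TorchShapeFlow's API or internals.
# Declared dims: a symbol ("B") or a fixed int (3).
# Observed dims: a symbol, an int, or None when the toy could not infer it.
Declared = Union[str, int]
Observed = Optional[Union[str, int]]


def check_against(
    declared: Sequence[Declared],
    observed: Sequence[Observed],
    bindings: Dict[str, Union[str, int]],
) -> List[str]:
    """Return human-readable problems; an empty list means the shape conforms."""
    if len(declared) != len(observed):
        return [f"rank mismatch: expected {len(declared)}, got {len(observed)}"]
    messages: List[str] = []
    for i, (want, got) in enumerate(zip(declared, observed)):
        if got is None:
            # Degrade visibly: report the gap instead of guessing a value.
            messages.append(f"dim {i}: unknown, cannot verify {want!r}")
        elif isinstance(want, int):
            # Fixed semantics (e.g. RGB channels) must match exactly.
            if not isinstance(got, int):
                messages.append(f"dim {i}: cannot confirm {got!r} == {want}")
            elif got != want:
                messages.append(f"dim {i}: expected {want}, got {got}")
        elif want not in bindings:
            # First use of a symbol binds it.
            bindings[want] = got
        elif bindings[want] != got:
            # Later uses must agree with the first binding.
            messages.append(f"dim {i}: {want!r} bound to {bindings[want]!r}, got {got!r}")
    return messages


bindings: Dict[str, Union[str, int]] = {}
print(check_against(("B", 3, "H", "W"), ("N", 3, 224, 224), bindings))  # []
print(check_against(("B", "H", "W"), ("N", 224, 112), bindings))
print(check_against(("B", 3, "H", "W"), ("N", 4, None, 224), bindings))
```

The second call is flagged because "W" was already bound to 224; the third reports both a channel mismatch (4 instead of 3) and an unverifiable "H".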

If Pydantic gives structure to data boundaries, TorchShapeFlow aims to do the same for tensor-shape boundaries in deep learning code.

Install

In Claude Code (two commands, no config-file editing):

/plugin marketplace add Davidxswang/torchshapeflow
/plugin install torchshapeflow@torchshapeflow

The first command registers this repo as a plugin marketplace (pulling from main by default). The second installs the torchshapeflow plugin from that marketplace, which wires in an MCP server, an agent skill, and a post-edit hook. Claude Code can then run tsf check, interpret the structured diagnostics, and propose annotations. No manual .mcp.json editing is required.

As a plain Python package (for CLI use or other agent runtimes):

pip install torchshapeflow

Documentation

Full docs at davidxswang.github.io/torchshapeflow

Contributing

git clone https://github.com/Davidxswang/torchshapeflow
cd torchshapeflow
make install   # uv sync --extra dev
make check     # format + lint + typecheck + tests

If you want to execute the example PyTorch scripts in examples/, install the separate examples extra:

uv sync --extra dev --extra examples

See docs/development.md for the full development guide: all make targets, CI workflow descriptions, and how to add new operators.


Download files


Source Distribution

torchshapeflow-0.7.2.tar.gz (104.7 kB view details)

Uploaded Source

Built Distribution


torchshapeflow-0.7.2-py3-none-any.whl (61.1 kB view details)

Uploaded Python 3

File details

Details for the file torchshapeflow-0.7.2.tar.gz.

File metadata

  • Download URL: torchshapeflow-0.7.2.tar.gz
  • Upload date:
  • Size: 104.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for torchshapeflow-0.7.2.tar.gz
Algorithm Hash digest
SHA256 445527dfefce9301c06bc9947bdcff7fa72c9b4ec9bfeae5de1686fbf28b6506
MD5 fd8c90c9f8f44096238d1a0e16517742
BLAKE2b-256 fc7d8d814c459905c010681a07d2bca2aee8677c18256af8af3784777091a4f7
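To check a downloaded sdist against the published SHA256 digest, a minimal stdlib sketch follows. It assumes the archive was saved locally; the helper names sha256_of and verify are illustrative, not part of TorchShapeFlow.

```python
import hashlib
from pathlib import Path

# Published SHA256 for torchshapeflow-0.7.2.tar.gz, copied from the table above.
EXPECTED_SHA256 = "445527dfefce9301c06bc9947bdcff7fa72c9b4ec9bfeae5de1686fbf28b6506"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash the file in chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

For example, verify(Path("torchshapeflow-0.7.2.tar.gz")) returns True only for an untampered download. pip can enforce the same check automatically in its hash-checking mode (--require-hashes).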


Provenance

The following attestation bundles were made for torchshapeflow-0.7.2.tar.gz:

Publisher: release.yml on Davidxswang/torchshapeflow

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file torchshapeflow-0.7.2-py3-none-any.whl.

File metadata

  • Download URL: torchshapeflow-0.7.2-py3-none-any.whl
  • Upload date:
  • Size: 61.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for torchshapeflow-0.7.2-py3-none-any.whl
Algorithm Hash digest
SHA256 d8f4dd08f1426b64dca54432448690f4e30913dadd7fd0247c2bcec4868a4442
MD5 d1018e72b05766ad73015015461e7144
BLAKE2b-256 c3dcda703405679c7bcdd58c28a8a5ff055933ac2c9d920bbae50c239332e6b9


Provenance

The following attestation bundles were made for torchshapeflow-0.7.2-py3-none-any.whl:

Publisher: release.yml on Davidxswang/torchshapeflow

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
