
Export PyTorch models as readable, editable Python code


torch.export_python

This is an experimental way of using the PyTorch compiler stack in which the compiler's output artifacts are entirely readable Python code. You can check this output into a source repository and edit it by hand, or regenerate it when you upgrade the compiler or modify the original plain Python source; we aim to make regeneration simple.

Why might you want something like this? Lots of reasons:

  • Precompilation. You don't want to have to run the compiler every time you run your model. The exported Python code here has no runtime compiler dependency.

  • Transparency. The code generated by the compiler is similar to the code you would have written if you had optimized by hand. You don't have to trust the compiler; you can audit the output and trust only that.

  • Portability. You don't have to regenerate the exported Python code if you don't want to, so you can upgrade PyTorch separately from upgrading your kernels. The generated code uses only stable PyTorch APIs, so it stays portable across PyTorch versions.

As of right now, we intend for export python to cover the torch.export tracing and Inductor code-generation layers of the compiler stack.

In future work, we may extend support to Dynamo in order to handle complicated input/output conventions and Python side-effect mutation.

Usage

Basic export (static shapes)

import torch
from torch_export_python import torch_export_python

def rms_norm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * weight

args = (torch.randn(2, 64, 256, device="cuda"),
        torch.randn(256, device="cuda"))

kernel = torch_export_python(rms_norm, args)
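For checking numerics, it helps to have an independent statement of what rms_norm computes. The sketch below writes the same formula for a single row (the last dimension) in plain Python with no torch dependency; rms_norm_reference is a hypothetical helper for illustration, not part of this package.

```python
import math

# Hypothetical reference helper (not part of torch_export_python).
def rms_norm_reference(row: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """RMSNorm over one row, mirroring rms_norm's formula."""
    # Mean of squares over the last dimension, like x.pow(2).mean(-1).
    mean_sq = sum(v * v for v in row) / len(row)
    # Reciprocal square root with the same epsilon, like torch.rsqrt(... + 1e-6).
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Scale every element, then apply the per-feature weight.
    return [v * inv_rms * w for v, w in zip(row, weight)]
```

With real tensors, a more direct check is to compare the result of kernel.run(*args) with rms_norm(*args) using torch.testing.assert_close, assuming run returns the same structure as the original function.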

Dynamic shapes

Use torch.export.Dim to mark dimensions that can vary at runtime:

from torch.export import Dim

B = Dim("B")
T = Dim("T")
kernel = torch_export_python(rms_norm, args, dynamic_shapes={
    "x": (B, T, None),
    "weight": (None,),
})
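In torch.export's dynamic_shapes spec, a Dim marks a dimension that may vary at runtime, None keeps a dimension static (fixed to its size in the example input), and reusing the same Dim across arguments requires those sizes to match. The plain-Python sketch below models those rules, with strings standing in for Dim objects. Both helpers are hypothetical: the real checks (including any min/max bounds on a Dim, which this sketch ignores) are performed by torch.export, and this package may derive ExportedKernel.dynamic_shape_names differently from symbolic_names here.

```python
from typing import Optional, Sequence

# Each argument name maps to one entry per dimension: a string names a
# dynamic dimension (stand-in for torch.export.Dim); None means static.
Spec = dict[str, Sequence[Optional[str]]]

def symbolic_names(spec: Spec) -> list[str]:
    """Hypothetical: collect distinct dynamic-dimension names in first-seen order."""
    names: list[str] = []
    for dims in spec.values():
        for d in dims:
            if d is not None and d not in names:
                names.append(d)
    return names

def shape_is_allowed(spec: Spec, example: dict[str, tuple[int, ...]],
                     new: dict[str, tuple[int, ...]]) -> bool:
    """Hypothetical: static dims must match the example; a shared name binds one size."""
    bound: dict[str, int] = {}
    for arg, dims in spec.items():
        if len(new[arg]) != len(example[arg]):
            return False  # the rank is always fixed
        for d, ex_size, new_size in zip(dims, example[arg], new[arg]):
            if d is None:
                if new_size != ex_size:
                    return False  # static dimension changed
            elif bound.setdefault(d, new_size) != new_size:
                return False  # same Dim name bound to two different sizes
    return True
```

For the spec above, x may change its B and T sizes freely, but its last dimension and weight must stay at 256.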

Saving and loading

The generated source depends only on torch, triton, and triton.language at runtime; no torch._inductor imports are needed.

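# save.py: write the generated code to a module
from pathlib import Path
Path("my_package").mkdir(exist_ok=True)
Path("my_package/my_kernel.py").write_text(kernel.source)

# run.py: import and run like any other Python module
import torch
from my_package.my_kernel import call
x = torch.randn(2, 64, 256, device="cuda")
weight = torch.randn(256, device="cuda")
result = call(x, weight)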
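Since the generated source is meant to depend only on torch, triton, and triton.language, one cheap safeguard when checking it into a repository is a test that parses the file and flags any torch._inductor import. A minimal sketch using the standard-library ast module; inductor_import_lines is a hypothetical helper, not part of this package.

```python
import ast

# Hypothetical helper (not part of torch_export_python).
def inductor_import_lines(source: str) -> list[int]:
    """Return the line numbers of imports that pull in torch._inductor."""
    lines: list[int] = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.module:
            # Also catch `from torch import _inductor` via module.alias.
            names = [node.module] + [f"{node.module}.{a.name}" for a in node.names]
        else:
            continue
        if any(n == "torch._inductor" or n.startswith("torch._inductor.") for n in names):
            lines.append(node.lineno)
    return sorted(lines)
```

In a unit test you might assert that inductor_import_lines(Path("my_package/my_kernel.py").read_text()) is empty.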

Running in-memory

For quick iteration without saving to disk:

result = kernel.run(*args)             # convenient: keeps your references live
result = kernel.boxed_run(list(args))  # explicit: input list is consumed
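The two entry points differ in calling convention. In the boxed style used by Inductor-generated code, the callee receives a list and clears it, so the list no longer keeps the inputs alive and they can be freed early if nothing else references them. The plain-Python sketch below illustrates the idea with hypothetical boxed_call and run functions; it is not this package's implementation.

```python
from typing import Any

# Hypothetical illustration of the boxed calling convention.
def boxed_call(args: list[Any]) -> list[Any]:
    """Unpack the inputs, then clear the caller's list (the list is consumed)."""
    x, weight = args
    args.clear()  # the caller's list no longer references the inputs
    # Stand-in for real kernel work: elementwise product of two lists.
    return [a * w for a, w in zip(x, weight)]

def run(*args: Any) -> list[Any]:
    """Unboxed wrapper: pass a fresh list so the caller's own references survive."""
    return boxed_call(list(args))
```

This is why boxed_run(list(args)) leaves your args tuple intact: only the temporary list is emptied.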

API Reference

  • torch_export_python(fn, args, *, dynamic_shapes=None) -> ExportedKernel — end-to-end pipeline: trace via torch.export, compile through Inductor, clean up the output.
  • export_and_codegen(fn, args, *, dynamic_shapes=None) -> str — stage 1 only: returns raw Inductor source before cleanup.
  • postprocess(raw_src, *, dynamic_shape_names=None, tensor_arg_count=0) -> ExportedKernel — stage 2 only: clean raw Inductor output and wrap in an ExportedKernel.
  • ExportedKernel.source — the generated Python source code string.
  • ExportedKernel.run(*args) — execute the kernel with the original function's arguments.
  • ExportedKernel.boxed_run(args_list) — execute with boxed calling convention (input list is consumed).
  • ExportedKernel.dynamic_shape_names — list of symbolic dimension names (e.g. ["B", "T"]).
  • ExportedKernel.tensor_arg_count — number of tensor arguments expected.
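Because export_and_codegen returns the raw Inductor source and ExportedKernel.source holds the cleaned version, one way to see what the cleanup stage changed is to diff the two strings. A minimal sketch using the standard-library difflib; cleanup_diff and the file labels are hypothetical.

```python
import difflib

# Hypothetical helper (not part of torch_export_python).
def cleanup_diff(raw_src: str, cleaned_src: str) -> str:
    """Unified diff from the raw Inductor output to the cleaned-up source."""
    return "".join(
        difflib.unified_diff(
            raw_src.splitlines(keepends=True),
            cleaned_src.splitlines(keepends=True),
            fromfile="raw_inductor.py",
            tofile="exported.py",
        )
    )
```

For example, print(cleanup_diff(export_and_codegen(rms_norm, args), kernel.source)); an empty string means the two sources are identical.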

License

BSD 3-Clause License. See LICENSE for details.



Download files


Source Distribution

torch_export_python-0.1.0.tar.gz (19.5 kB)

Uploaded Source

Built Distribution


torch_export_python-0.1.0-py3-none-any.whl (17.3 kB)

Uploaded Python 3

File details

Details for the file torch_export_python-0.1.0.tar.gz.

File metadata

  • Download URL: torch_export_python-0.1.0.tar.gz
  • Size: 19.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.9 (macOS)

File hashes

Hashes for torch_export_python-0.1.0.tar.gz
  • SHA256: 4a1de18c3dcec9105cc6034b6c33520613fcff45521b0f39cb003119c0cd511e
  • MD5: d6290e583df91cd70d1dcde178c9fae1
  • BLAKE2b-256: 1cb3e67c338200a07e4914dc6540c57860a9903ce9bac9c749f1728c69520c2c


File details

Details for the file torch_export_python-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: torch_export_python-0.1.0-py3-none-any.whl
  • Size: 17.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.9 (macOS)

File hashes

Hashes for torch_export_python-0.1.0-py3-none-any.whl
  • SHA256: e1f48cf97a0c13b527439f86c9cd09bad7b12ab860bb3b880ffbec388629e6af
  • MD5: 90153f3c094bc2d7282fecc2f681625d
  • BLAKE2b-256: 17216e76135b6b24cc194e2605b8e8f9986dfb043053d991d972f9c901f8e8ac

