Export PyTorch models as readable, editable Python code
torch.export_python
This is an experimental way of using the PyTorch compiler stack in which the compiler's output artifacts are entirely readable Python code. You can check this output into a source repository and edit it by hand, and we also try to make it simple to regenerate the code after you upgrade the compiler or modify the original plain Python source.
Why might you want something like this? Lots of reasons:
- Precompilation. You don't want to run the compiler every time you run your model. The exported Python code has no runtime compiler dependency.
- Transparency. The code generated by the compiler is similar to what you would have written if you had optimized by hand, so you don't have to trust the compiler: you can audit the output and trust only that.
- Portability. You don't have to regenerate the exported Python code if you don't want to; you can upgrade Python separately from upgrading your kernels. The generated code uses only stable PyTorch APIs and is portable across versions.
As of right now, we intend for export_python to cover these layers of the compiler stack:
- AOTAutograd. The goal is to take forward-only PyTorch code and generate a custom autograd function that binds together a forward and a backward implementation. This is blocked on AOTAutograd codegen of runtime wrappers; see:
  - https://github.com/pytorch/pytorch/pull/176741
  - https://github.com/pytorch/pytorch/pull/179599
  - https://github.com/pytorch/pytorch/pull/179061
  - https://github.com/pytorch/pytorch/pull/178927
  - https://github.com/pytorch/pytorch/pull/178675
- Inductor. We can take PyTorch code and turn it into fused Triton kernels.
In future work, we may also extend support to Dynamo, to handle complicated input/output conventions and Python side-effect mutation.
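The AOTAutograd layer generates a custom autograd function. This pattern pairs a forward, which saves the values it will need later, with a backward, which consumes those saved values to produce input gradients.

The sketch below mimics the shape of `torch.autograd.Function` in plain Python for a single row of `rms_norm`, using a hand-derived gradient. It has a static `forward(ctx, ...)` that calls `ctx.save_for_backward`, and a static `backward(ctx, grad_out)`. `Ctx` and `RMSNormRow` are hypothetical stand-ins and are not what AOTAutograd actually emits. The sketch is also simplified: real `torch.autograd.Function.backward` returns one gradient per forward input, including `None` for `eps`.

```python
import math
from dataclasses import dataclass


@dataclass
class Ctx:
    # Hypothetical stand-in for the ctx object that torch.autograd.Function passes around.
    saved: tuple = ()

    def save_for_backward(self, *values):
        self.saved = values


class RMSNormRow:
    """One row of rms_norm: y_i = x_i * r * w_i, with r = (mean(x^2) + eps) ** -0.5."""

    @staticmethod
    def forward(ctx, x, w, eps=1e-6):
        n = len(x)
        r = 1.0 / math.sqrt(sum(v * v for v in x) / n + eps)
        ctx.save_for_backward(x, w, r)  # keep inputs and the reciprocal RMS for backward
        return [xi * r * wi for xi, wi in zip(x, w)]

    @staticmethod
    def backward(ctx, grad_out):
        x, w, r = ctx.saved
        n = len(x)
        # dL/dx_j = r * g_j * w_j - (r**3 / n) * x_j * sum_i(g_i * w_i * x_i)
        dot = sum(g * wi * xi for g, wi, xi in zip(grad_out, w, x))
        grad_x = [r * g * wi - r ** 3 * xj * dot / n
                  for g, wi, xj in zip(grad_out, w, x)]
        # dL/dw_i = g_i * x_i * r
        grad_w = [g * xi * r for g, xi in zip(grad_out, x)]
        return grad_x, grad_w
```

The backward reuses `r` from the forward instead of recomputing the reduction. Deciding what to save versus recompute is the trade-off AOTAutograd's forward/backward partitioning makes for real graphs.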
Usage
Basic export (static shapes)
import torch
from torch_export_python import torch_export_python
def rms_norm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * weight
args = (torch.randn(2, 64, 256, device="cuda"),
        torch.randn(256, device="cuda"))
kernel = torch_export_python(rms_norm, args)
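Inductor's job in this example is to fuse `rms_norm`'s elementwise ops and its reduction into a small number of Triton kernels. The sketch below is a rough pure-Python analogue, not the generated code (which is Triton). It contrasts two versions:

- An unfused version that materializes every intermediate, as eager PyTorch does when each op writes a new tensor.
- A fused version that handles each row in one pass without storing the squared values or the normalized row.

The function names are hypothetical.

```python
import math

EPS = 1e-6


def rms_norm_unfused(x, weight):
    # x: list of rows; weight: one row. Each step builds a full intermediate.
    sq = [[v * v for v in row] for row in x]                      # x.pow(2)
    mean = [sum(row) / len(row) for row in sq]                    # .mean(-1, keepdim=True)
    inv = [1.0 / math.sqrt(m + EPS) for m in mean]                # torch.rsqrt(... + 1e-6)
    scaled = [[v * r for v in row] for row, r in zip(x, inv)]     # x * rsqrt(...)
    return [[v * w for v, w in zip(row, weight)] for row in scaled]  # * weight


def rms_norm_fused(x, weight):
    # Roughly one "program" per row: reduce, then normalize and scale in the
    # same loop, without keeping the squares or the normalized row around.
    out = []
    for row in x:
        inv = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + EPS)
        out.append([v * inv * w for v, w in zip(row, weight)])
    return out
```

Both functions compute the same values. On a GPU, the fused form wins because it avoids writing and re-reading intermediates in global memory, which is what Inductor's fusion targets.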
Dynamic shapes
Use torch.export.Dim to mark dimensions that can vary at runtime:
from torch.export import Dim
B = Dim("B")
T = Dim("T")
kernel = torch_export_python(rms_norm, args, dynamic_shapes={
    "x": (B, T, None),
    "weight": (None,),
})
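With a dict, `torch.export` expects the `dynamic_shapes` keys to match the function's parameter names (`x` and `weight` here), and `None` marks a static dimension.

The sketch below is a hypothetical helper, not part of torch_export_python. It does two things:

- Checks the keys against `inspect.signature`.
- Collects the symbolic dimension names in first-appearance order. We assume `ExportedKernel.dynamic_shape_names` is ordered the same way.

It assumes `torch.export.Dim` objects expose their name as `__name__`, and falls back to plain strings so it runs without PyTorch.

```python
import inspect


def dynamic_dim_names(fn, dynamic_shapes):
    """Hypothetical helper: validate dynamic_shapes keys and collect dim names."""
    params = inspect.signature(fn).parameters
    unknown = set(dynamic_shapes) - set(params)
    if unknown:
        raise ValueError(f"dynamic_shapes keys not in signature: {sorted(unknown)}")
    names = []
    for spec in dynamic_shapes.values():
        for dim in spec:
            if dim is None:  # None marks a static dimension
                continue
            # Assumed: Dim("B").__name__ == "B"; plain strings are used as-is.
            name = getattr(dim, "__name__", dim)
            if name not in names:
                names.append(name)
    return names


def rms_norm(x, weight):  # signature only; the body does not matter here
    ...
```

For example, `dynamic_dim_names(rms_norm, {"x": ("B", "T", None), "weight": (None,)})` returns `["B", "T"]`, while a misspelled key such as `"X"` raises `ValueError`.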
Saving and loading
The generated source depends only on torch, triton, and triton.language at runtime; it needs no torch._inductor imports.
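# save.py: write the generated code to a module
from pathlib import Path
out = Path("my_package/my_kernel.py")
out.parent.mkdir(parents=True, exist_ok=True)
out.write_text(kernel.source)  # `kernel` comes from torch_export_python(...) above
# run.py: import and run like any other Python module
import torch
from my_package.my_kernel import call
x = torch.randn(2, 64, 256, device="cuda")
weight = torch.randn(256, device="cuda")
result = call(x, weight)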
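Because the exported artifact is ordinary source text, the save-then-import round trip needs nothing beyond the standard library. The sketch below writes a stand-in source string to a temporary directory and loads it with `importlib`. This is the same resolution `from my_package.my_kernel import call` performs once `my_package` is on `sys.path`.

The stand-in source is hypothetical: a real `ExportedKernel.source` contains Triton kernels and needs torch and triton to import.

```python
import importlib.util
import tempfile
from pathlib import Path

# Hypothetical stand-in for kernel.source: plain Python with a `call` entry point.
FAKE_SOURCE = '''
def call(x, weight):
    return [v * w for v, w in zip(x, weight)]
'''


def load_module_from_source(source: str, name: str, directory: Path):
    """Write `source` to <directory>/<name>.py and import it as a fresh module."""
    path = directory / f"{name}.py"
    path.write_text(source)
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    return module


with tempfile.TemporaryDirectory() as tmp:
    mod = load_module_from_source(FAKE_SOURCE, "my_kernel", Path(tmp))
    result = mod.call([1.0, 2.0], [3.0, 4.0])
```

In a real project, the checked-in file is imported with a normal `import` statement. `importlib` is only needed when the module path is chosen at runtime.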
Running in-memory
For quick iteration without saving to disk:
result = kernel.run(*args) # convenient: keeps your references live
result = kernel.boxed_run(list(args)) # explicit: input list is consumed
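The difference between `run` and `boxed_run` mirrors Inductor's boxed calling convention. The compiled function receives a single list, unpacks it, and clears it. As a result, the caller's list no longer keeps the inputs alive, and their memory can be released as soon as the kernel is done with them.

Below is a plain-Python sketch of that convention. The function names are hypothetical, and list arithmetic stands in for the kernels.

```python
def boxed_call(args: list):
    # Boxed convention: unpack, then clear the caller's list so it no longer
    # holds references to the inputs.
    x, weight = args
    args.clear()
    out = [v * w for v, w in zip(x, weight)]
    del x, weight  # drop local references; inputs are freed if nothing else holds them
    return out


def run(*args):
    # Convenience wrapper: builds a fresh list, so the caller's own
    # references stay alive and untouched.
    return boxed_call(list(args))
```

After `boxed_call(inputs)`, `inputs` is empty. After `run(x, w)`, both `x` and `w` are still available to the caller.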
API Reference
- torch_export_python(fn, args, *, dynamic_shapes=None) -> ExportedKernel: end-to-end pipeline. Traces via torch.export, compiles through Inductor, and cleans up the output.
- export_and_codegen(fn, args, *, dynamic_shapes=None) -> str: stage 1 only; returns the raw Inductor source before cleanup.
- postprocess(raw_src, *, dynamic_shape_names=None, tensor_arg_count=0) -> ExportedKernel: stage 2 only; cleans raw Inductor output and wraps it in an ExportedKernel.
- ExportedKernel.source: the generated Python source code string.
- ExportedKernel.run(*args): executes the kernel with the original function's arguments.
- ExportedKernel.boxed_run(args_list): executes with the boxed calling convention (the input list is consumed).
- ExportedKernel.dynamic_shape_names: list of symbolic dimension names (e.g. ["B", "T"]).
- ExportedKernel.tensor_arg_count: number of tensor arguments expected.
License
BSD 3-Clause License. See LICENSE for details.
Download files
Source Distribution: torch_export_python-0.1.0.tar.gz
Built Distribution: torch_export_python-0.1.0-py3-none-any.whl
File details
Details for the file torch_export_python-0.1.0.tar.gz.
File metadata
- Download URL: torch_export_python-0.1.0.tar.gz
- Upload date:
- Size: 19.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4a1de18c3dcec9105cc6034b6c33520613fcff45521b0f39cb003119c0cd511e |
| MD5 | d6290e583df91cd70d1dcde178c9fae1 |
| BLAKE2b-256 | 1cb3e67c338200a07e4914dc6540c57860a9903ce9bac9c749f1728c69520c2c |
File details
Details for the file torch_export_python-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torch_export_python-0.1.0-py3-none-any.whl
- Upload date:
- Size: 17.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e1f48cf97a0c13b527439f86c9cd09bad7b12ab860bb3b880ffbec388629e6af |
| MD5 | 90153f3c094bc2d7282fecc2f681625d |
| BLAKE2b-256 | 17216e76135b6b24cc194e2605b8e8f9986dfb043053d991d972f9c901f8e8ac |