
Python DSL for writing PTX kernels, callable from JAX, PyTorch, and torch.compile

Project description

pyptx

Write PTX kernels in Python. Launch them from jax.jit, PyTorch, and torch.compile.

pyptx is a Python DSL for handwritten PTX on NVIDIA Hopper (sm_90a) and Blackwell (sm_100a).

One call = one instruction. No optimizer, no autotuner, no tensor IR between the Python function and the PTX it emits.

  • explicit registers, predicates, barriers, shared memory
  • Hopper: WGMMA, TMA 2D/3D with multicast, mbarriers, cluster launch
  • Blackwell: tcgen05.mma / .ld, TMEM, SMEM descriptors, warp specialization
  • callable from JAX, PyTorch eager, and torch.compile
  • real PTX parser + emitter + transpiler — round-trips 218+ real PTX files byte-identical

Docs: pyptx.dev · Examples: examples/hopper/, examples/blackwell/ · API: pyptx.dev/api


Install

| Command | What you get |
| --- | --- |
| `pip install pyptx` | DSL, parser, emitter, transpiler (no GPU runtime) |
| `pip install 'pyptx[torch]'` | + PyTorch eager and torch.compile launch path |
| `pip install 'pyptx[jax]'` | + jax.jit launch path via typed FFI |
| `pip install 'pyptx[all]'` | + both PyTorch and JAX |

Tip: pip install ninja so the PyTorch C++ extension JIT-builds on first launch (drops dispatch overhead from ~34 µs to ~14 µs).

Performance

Blackwell (B200, bf16)

| Kernel | Shape | pyptx | cuBLAS | pyptx / cuBLAS |
| --- | --- | --- | --- | --- |
| GEMM (tcgen05.mma, 4-stage pipeline, 1SM) | 8192³ | 1240 TFLOPS | 1610 | 77% |
| GEMM (1SM) | 4096³ | 1194 TFLOPS | 1532 | 78% |
| GEMM 2SM (cta_group::2, 5-stage) | 2048³ | 649 TFLOPS (beats 1SM) | 1006 | 64% |
| Grouped GEMM (tcgen05, MoE) | G=4 M=2048 N=256 K=2048 | 401 TFLOPS | torch ref | ~10.0× |
| RMS norm / Layer norm / SwiGLU | Blackwell ports, benchmarked vs torch ref | | | see kernel suite |

Hopper (H100 SXM5, bf16 / f32)

| Kernel | Shape | pyptx | vs reference |
| --- | --- | --- | --- |
| GEMM (wgmma, warp-specialized) | 8192³ | 815 TFLOPS | beats cuBLAS ≥ 6K |
| Grouped GEMM (bf16→f32) | G=8 M=K=2048 | 104 TFLOPS | |
| RMS norm (f32) | B=2048 N=8192 | 2.6 TB/s (88% HBM) | 3.9× torch |
| Layer norm (f32) | B=2048 N=8192 | 2.5 TB/s (83% HBM) | 1.5× F.layer_norm |
| SwiGLU (f32) | M=2048 F=8192 | 2.8 TB/s (94% HBM) | 1.6× F.silu(g)*u |
| Flash attention (bf16) | M=N=4096, HD=64 | 88 µs | 3.0× naive torch |

Full benchmark tables + reproduction commands: pyptx.dev/performance.

PyTorch dispatch tiers:

  • CUDA graph replay: ~4 µs per launch
  • Turbo eager: ~14 µs (cached C++ extension)
  • torch.compile: ~14–22 µs (custom_op path)
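The TFLOPS figures above follow the standard 2·M·N·K flop count for GEMM. A quick sanity check (illustrative arithmetic only, not the project's benchmark harness):

```python
def gemm_tflops(m: int, n: int, k: int, seconds: float) -> float:
    """Achieved TFLOPS for an M x N x K GEMM: 2*M*N*K flops / runtime."""
    return 2 * m * n * k / seconds / 1e12

# 1240 TFLOPS at 8192^3 implies a kernel time of roughly 0.89 ms:
flops = 2 * 8192**3
runtime = flops / 1240e12
print(f"{runtime * 1e3:.2f} ms")  # prints 0.89 ms
```

The same arithmetic recovers every ratio in the tables, e.g. 1240 / 1610 ≈ 77% of cuBLAS.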

What it looks like

from pyptx import kernel, reg, smem, ptx, Tile
from pyptx.types import bf16, f32

@kernel(
    in_specs=(Tile("M", "K", bf16), Tile("K", "N", bf16)),
    out_specs=(Tile("M", "N", f32),),
    grid=lambda M, N, K: (N // 64, M // 64),
    block=(128, 1, 1),
    arch="sm_90a",
)
def gemm(A, B, C):
    sA = smem.wgmma_tile(bf16, (64, 16), major="K")
    sB = smem.wgmma_tile(bf16, (16, 64), major="MN")
    acc = reg.array(f32, 32)
    # ... TMA loads + ptx.wgmma.mma_async(...) — each call emits exactly one PTX instruction

Every ptx.* call is a single PTX instruction. print(gemm.ptx()) shows exactly what you wrote.
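The "one call = one instruction" contract can be pictured with a toy emitter (a hypothetical miniature for illustration, not pyptx's actual internals): each DSL call appends exactly one PTX line, so the emitted text is a 1:1 transcript of the Python calls.

```python
# Illustrative sketch: no optimizer or IR sits between the call and the line.
class TinyEmitter:
    def __init__(self):
        self.lines = []

    def mov(self, dst, src):
        self.lines.append(f"mov.u32 {dst}, {src};")

    def add(self, dst, a, b):
        self.lines.append(f"add.u32 {dst}, {a}, {b};")

    def ptx(self) -> str:
        return "\n".join(self.lines)

e = TinyEmitter()
e.mov("%r1", "%tid.x")
e.add("%r2", "%r1", "1")
print(e.ptx())
# mov.u32 %r1, %tid.x;
# add.u32 %r2, %r1, 1;
```

Two calls in, two instructions out; that is the whole mental model.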

One kernel, three runtime paths

The same kernel object works in JAX, PyTorch eager, and torch.compile:

# PyTorch eager
out = gemm(a, b)

# torch.compile
out = torch.compile(gemm)(a, b)

# JAX jit (lowers through typed FFI)
out = jax.jit(gemm)(a, b)

Under the hood the PTX is JITed through cuModuleLoadData, registered with a ~150-line C++ launch shim, and dispatched from PyTorch via torch.library.custom_op or from JAX via jax.ffi.ffi_call.
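One plausible shape of that launch path's caching layer, sketched in pure Python (names hypothetical; the real shim is C++ and calls the CUDA driver API): each distinct PTX string is JIT-compiled at most once, keyed by hash, so repeated launches skip straight to dispatch.

```python
import hashlib

# Hypothetical sketch of module caching around cuModuleLoadData.
_module_cache: dict[str, object] = {}

def load_module(ptx: str, jit=lambda src: object()) -> object:
    """Return the loaded module for this PTX, JIT-compiling only on a miss.

    `jit` stands in for the real cuModuleLoadData call.
    """
    key = hashlib.sha256(ptx.encode()).hexdigest()
    if key not in _module_cache:
        _module_cache[key] = jit(ptx)
    return _module_cache[key]
```

With the module cached, per-launch cost reduces to argument packing and the dispatch tiers listed above.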


Transpile existing PTX into pyptx

pyptx is also a real PTX-to-Python transpiler. Feed it output from nvcc, Triton, Pallas, or any other source:

python -m pyptx.codegen kernel.ptx --sugar --name my_kernel > my_kernel.py

--sugar demangles names, raises spin-loops into ptx.loop(...), collapses mbarrier-wait blocks, and groups expression chains. Round-trips are byte-identical on 218+ corpus files (CUTLASS, Triton, fast.cu, DeepGEMM, ThunderKittens, LLVM tests).
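The byte-identical round-trip guarantee is a simple property to state: emitting the parsed form must reproduce the input exactly. A toy illustration with a trivially line-based "parser" (the real parser builds a full PTX AST, but the property under test is the same):

```python
def parse(src: str) -> list[str]:
    """Toy stand-in for the PTX parser: one node per line."""
    return src.split("\n")

def emit(nodes: list[str]) -> str:
    """Toy stand-in for the emitter."""
    return "\n".join(nodes)

sample = "ld.global.f32 %f1, [%rd1];\nst.global.f32 [%rd2], %f1;"
assert emit(parse(sample)) == sample  # byte-identical round trip
```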

The 815 TFLOPS Hopper GEMM in examples/hopper/gemm_highperf_hopper.py is exactly this workflow applied to fast.cu's kernel12.


Start here

Hopper (sm_90a):

  • examples/hopper/rms_norm.py — simplest real kernel, v4 loads + warp reduce
  • examples/hopper/grouped_gemm.py — multi-k WGMMA for MoE shapes
  • examples/hopper/gemm_highperf_hopper.py — warp-specialized 815 TFLOPS GEMM

Blackwell (sm_100a):

  • examples/blackwell/tcgen05_suite.py — 13 isolated tcgen05 primitives (alloc, MMA, ld, commit/fence, GEMM probes). Run this first on a B200 to verify the runtime stack.
  • examples/blackwell/gemm_highperf_blackwell.py — build_gemm (1SM, 4-stage ring buffer, 1.24 PFLOPS at 8192³ bf16) and build_gemm_2sm (2SM cta_group::2 cooperative MMA, 5-stage).
  • examples/blackwell/gemm_experimental_blackwell.py — persistent and Pallas-style experimental GEMM paths, plus the no-TMA tcgen05 debug GEMM.
  • examples/blackwell/grouped_gemm.py — G-problem MoE grouped GEMM on top of the same tcgen05.mma mainloop, bit-exact against einsum("gmk,gkn->gmn") through G=8 M=1024 N=128 K=1024.
  • examples/blackwell/rms_norm.py / layer_norm.py / swiglu.py — Hopper kernels re-targeted to sm_100a.
  • benchmarks/bench_blackwell_gemm.py — reproduce the 1SM + 2SM + cuBLAS table above.
  • benchmarks/bench_blackwell_kernels.py — Blackwell grouped GEMM, RMSNorm, LayerNorm, and SwiGLU benchmark suite.

Docs: pyptx.dev · API: pyptx.dev/api

Status

0.1.0, pre-launch. Scope:

  • handwritten PTX DSL with full Hopper ISA (wgmma, TMA 2D/3D, mbarriers, cluster)
  • Blackwell tcgen05 ISA (alloc, mma.kind::f16/tf32/f8, ld/st, commit, fence) with instruction-descriptor + SMEM-descriptor helpers
  • PTX parser / emitter with 218+ corpus round-trip tests
  • PTX → Python transpiler with sugar pass
  • JAX runtime integration (typed FFI)
  • PyTorch eager + torch.compile + CUDA graph replay
  • C++ dispatch extension for low-overhead launches
  • GMMA/UMMA SMEM swizzle helpers (B32 / B64 / B128, CuTe-compatible Swizzle<B,4,3>)
  • PyTorch autograd via differentiable_kernel
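The swizzle helpers follow the CuTe-style `Swizzle<B,M,S>` XOR pattern: B bits at position M+S are XORed into the B bits at position M. A sketch of one common formulation (assumption: this matches pyptx's helper; shown only to illustrate the bit pattern, with `Swizzle<3,4,3>` as the B128 case):

```python
def swizzle(x: int, B: int = 3, M: int = 4, S: int = 3) -> int:
    """CuTe-style Swizzle<B,M,S>: XOR bits [M+S, M+S+B) into bits [M, M+B)."""
    yyy = ((1 << B) - 1) << (M + S)
    return x ^ ((x & yyy) >> S)

# Because the source and target bit ranges are disjoint, the map is an
# involution (its own inverse), so it permutes SMEM offsets without
# collisions -- the property that lets swizzled layouts avoid bank conflicts.
assert all(swizzle(swizzle(x)) == x for x in range(1 << 12))
```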

License

Apache-2.0. See LICENSE.



Download files

Download the file for your platform.

Source Distribution

  • pyptx-0.1.0.tar.gz (863.1 kB) — Source

Built Distributions

  • pyptx-0.1.0-py3-none-any.whl (244.3 kB) — Python 3
  • pyptx-0.1.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (279.6 kB) — CPython 3.13, manylinux glibc 2.24+ / 2.28+ x86-64
  • pyptx-0.1.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl (277.2 kB) — CPython 3.13, manylinux glibc 2.24+ / 2.28+ ARM64
  • pyptx-0.1.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (279.6 kB) — CPython 3.12, manylinux glibc 2.24+ / 2.28+ x86-64
  • pyptx-0.1.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl (277.2 kB) — CPython 3.12, manylinux glibc 2.24+ / 2.28+ ARM64
  • pyptx-0.1.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (279.6 kB) — CPython 3.11, manylinux glibc 2.24+ / 2.28+ x86-64
  • pyptx-0.1.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl (277.2 kB) — CPython 3.11, manylinux glibc 2.24+ / 2.28+ ARM64
  • pyptx-0.1.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (279.3 kB) — CPython 3.10, manylinux glibc 2.24+ / 2.28+ x86-64
  • pyptx-0.1.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl (277.1 kB) — CPython 3.10, manylinux glibc 2.24+ / 2.28+ ARM64

File details

Details for the file pyptx-0.1.0.tar.gz.

File metadata

  • File: pyptx-0.1.0.tar.gz
  • Size: 863.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for pyptx-0.1.0.tar.gz
Algorithm Hash digest
SHA256 9a8adddcc83a23af23e6d71c6f392579e2e1d140a000977295884ee889277c26
MD5 b40ffb6ef78bceb0d528c3b42cd5fdf5
BLAKE2b-256 c0553dcfb7ffafec5af8afc8b52948a960df3b0d1ca105ee763d205d701c4b70

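To check a downloaded artifact against the SHA256 digest above, standard hashlib usage suffices (the file path is whatever you downloaded):

```python
import hashlib

def sha256_of(path: str) -> str:
    """SHA256 hex digest of a file, read in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

# sha256_of("pyptx-0.1.0.tar.gz") should match the SHA256 value listed above.
```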

Provenance

The following attestation bundles were made for pyptx-0.1.0.tar.gz:

Publisher: build-wheels.yml on patrick-toulme/pyptx

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

