pyptx

Write PTX kernels in Python. Launch them from jax.jit, PyTorch, and torch.compile.

pyptx is a Python DSL for handwritten PTX on NVIDIA Ampere (sm_80), Ada (sm_89), Hopper (sm_90a), Blackwell datacenter (sm_100a), and Blackwell workstation (sm_120). Pre-Ampere targets like Turing (sm_75, T4) work for kernels that stay within the sm_75 ISA — anything using cp.async, mbarrier, bf16, wgmma, tcgen05, or TMA needs an Ampere-or-newer card.
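As a rough guide, the feature gating above can be summarized in a small table. This is illustrative only — `MIN_ARCH` and `min_sm_for` below are not part of the pyptx API, just a restatement of the paragraph above:

```python
# Illustrative mapping from PTX feature to the minimum supported target
# (per the compatibility notes above -- not the library's own validation).
MIN_ARCH = {
    "cp.async": 80,    # Ampere async global->SMEM copies
    "mbarrier": 80,    # arrive/wait barriers
    "bf16": 80,
    "wgmma": 90,       # Hopper warpgroup MMA
    "tma": 90,         # Tensor Memory Accelerator
    "tcgen05": 100,    # Blackwell 5th-gen tensor core ops
}

def min_sm_for(features):
    """Smallest sm_XX that supports every feature in `features`.

    Kernels that stay within the sm_75 ISA still run on a T4."""
    return max((MIN_ARCH[f] for f in features), default=75)
```

For example, a kernel using cp.async and mbarrier needs at least sm_80, while anything touching tcgen05 needs a Blackwell card.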

One call = one instruction. No optimizer, no autotuner, no tensor IR between the Python function and the PTX it emits.

  • explicit registers, predicates, barriers, shared memory
  • Ampere: mma.sync (m16n8k{8,16,32}), cp.async, ldmatrix, SMEM staging
  • Hopper: WGMMA, TMA 2D/3D with multicast, mbarriers, cluster launch
  • Blackwell: tcgen05.mma / .ld, TMEM, SMEM descriptors, warp specialization
  • callable from JAX, PyTorch eager, and torch.compile
  • arch="auto" picks the right target for the current GPU at trace time (validated on T4, A100, L4, H100, B200, RTX Pro 6000 Blackwell)
  • real PTX parser + emitter + transpiler — round-trips 218+ real PTX files byte-identical

Docs: pyptx.dev · Examples: examples/ampere/, examples/hopper/, examples/blackwell/ · API: pyptx.dev/api


Install

| Command | What you get |
| --- | --- |
| pip install pyptx | DSL, parser, emitter, transpiler (no GPU runtime) |
| pip install 'pyptx[torch]' | + PyTorch eager and torch.compile launch path |
| pip install 'pyptx[jax]' | + jax.jit launch path via typed FFI |
| pip install 'pyptx[all]' | + both PyTorch and JAX |

Tip: pip install ninja so the PyTorch C++ extension JIT-builds on first launch (drops dispatch overhead from ~34 µs to ~14 µs).

Performance

Blackwell (B200, bf16)

| Kernel | Shape | pyptx | cuBLAS | best / cuBLAS |
| --- | --- | --- | --- | --- |
| GEMM (tcgen05.mma, 4-stage pipeline, 1SM) | 8192³ | 1240 TFLOPS | 1610 TFLOPS | 77% |
| GEMM (1SM) | 4096³ | 1194 TFLOPS | 1532 TFLOPS | 78% |
| GEMM 2SM (cta_group::2, 5-stage) | 2048³ | 649 TFLOPS (beats 1SM) | 1006 TFLOPS | 64% |
| Grouped GEMM (tcgen05, MoE) | G=4 M=2048 N=256 K=2048 | 401 TFLOPS | torch ref | ~10.0× |
| RMS norm / Layer norm / SwiGLU (maintained Blackwell ports) | various | benchmarked vs torch ref | | see kernel suite |

Hopper (H100 SXM5, bf16 / f32)

| Kernel | Shape | pyptx | vs reference |
| --- | --- | --- | --- |
| GEMM (wgmma, warp-specialized) | 8192³ | 815 TFLOPS | beats cuBLAS ≥ 6K |
| Grouped GEMM (bf16→f32) | G=8 M=K=2048 | 104 TFLOPS | |
| RMS norm (f32) | B=2048 N=8192 | 2.6 TB/s (88% HBM) | 3.9× torch |
| Layer norm (f32) | B=2048 N=8192 | 2.5 TB/s (83% HBM) | 1.5× F.layer_norm |
| SwiGLU (f32) | M=2048 F=8192 | 2.8 TB/s (94% HBM) | 1.6× F.silu(g)*u |
| Softmax (f32, row-wise) | B=2048 N=8192 | 2.8 TB/s (95% HBM) | 1.16× torch.softmax |
| Flash attention (bf16) | M=N=4096, HD=64 | 88 µs | 3.0× naive torch |

Ampere (A100 80GB, bf16 / f32)

| Kernel | Shape | pyptx | vs reference |
| --- | --- | --- | --- |
| GEMM (ldmatrix.x4 + cp.async 4-stage + register frag double-buffer + XOR swizzle + serpentine mma.sync) | 4096³ bf16 | 162 TFLOPS | cuBLAS 223 TFLOPS (73%) |
| GEMM (same kernel) | 2048³ bf16 | 108 TFLOPS | cuBLAS 158 TFLOPS (68%) |
| GEMM (simple mma.sync + 2-stage pipeline, teaching kernel) | 4096³ bf16 | 64 TFLOPS | cuBLAS 230 TFLOPS (28%) |
| RMS norm (f32) | B=2048 N=8192 | 928 GB/s | 2.2× torch |
| SwiGLU (f32) | M=2048 F=8192 | 1.33 TB/s | 1.35× F.silu(g)*u |
| Layer norm (f32) | B=2048 N=8192 | 916 GB/s | 0.89× F.layer_norm (torch's fused kernel is hard to beat) |

A100 numbers reproduce via python benchmarks/bench_ampere_kernels.py. The high-perf A100 GEMM follows the CUTLASS SM80 / MatmulTutorial v15 design pattern: 128×128×32 CTA tile, 4 warps in 2×2 owning 64×64 output sub-tiles each, warp-collective ldmatrix.x4 for SMEM→register fragment loads, 4-stage cp.async ring buffer (3 in-flight), register fragment double-buffering that pre-loads the next K-iter's first K-block during the current iter's last mma, CUTLASS XOR swizzle (atom ^= row & 3) on all SMEM paths to eliminate ldmatrix bank conflicts, serpentine N-fragment order for adjacent-mma operand reuse, and per-thread offset hoisting so each inner-loop ldmatrix is one add instead of 5+ ops. 64 mma.sync.m16n8k16 per warp per K-iter (256 per CTA per K-iter). We haven't spent much time tuning this kernel — the 27% remaining gap is addressable (persistent / stream-K scheduling, more aggressive instruction-level overlap, autotuned tile sizes). See examples/ampere/gemm_highperf_ampere.py for the full kernel.
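The XOR swizzle mentioned above (atom ^= row & 3) is tiny in isolation. A standalone sketch of the atom remapping, assuming 16-byte atoms with 8 atoms per SMEM row (the actual kernel layout may differ in detail):

```python
def swizzle_atom(row: int, atom: int) -> int:
    """CUTLASS SM80-style XOR swizzle: permute the 16B atoms within each
    SMEM row (atom ^= row & 3) so ldmatrix lanes hit distinct banks."""
    return atom ^ (row & 3)

# Within every row the mapping permutes the 8 atom slots (XOR by a
# constant is a bijection), and the same atom index lands in different
# slots across 4 consecutive rows -- which is what kills the
# ldmatrix bank conflicts.
for row in range(8):
    slots = [swizzle_atom(row, a) for a in range(8)]
    assert sorted(slots) == list(range(8))
```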

Full benchmark tables + reproduction commands: pyptx.dev/performance.
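For sanity-checking the numbers in the tables above, the conversions are simple. A generic sketch (the repo's benchmark scripts may time things differently):

```python
def gemm_tflops(M: int, N: int, K: int, seconds: float) -> float:
    """A GEMM does 2*M*N*K FLOPs (one multiply + one add per MAC)."""
    return 2 * M * N * K / seconds / 1e12

def effective_tbps(bytes_moved: int, seconds: float) -> float:
    """Effective bandwidth for memory-bound kernels, in TB/s."""
    return bytes_moved / seconds / 1e12

# An 8192^3 GEMM at 1240 TFLOPS finishes in roughly 0.89 ms;
# a B=2048, N=8192 f32 norm moves about B*N*4 bytes each direction.
time_8192 = 2 * 8192**3 / 1240e12
```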

PyTorch dispatch tiers:

  • CUDA graph replay: ~4 µs per launch
  • Turbo eager: ~14 µs (cached C++ extension)
  • torch.compile: ~14–22 µs (custom_op path)

What it looks like

from pyptx import kernel, reg, smem, ptx, Tile
from pyptx.types import bf16, f32

@kernel(
    in_specs=(Tile("M", "K", bf16), Tile("K", "N", bf16)),
    out_specs=(Tile("M", "N", f32),),
    grid=lambda M, N, K: (N // 64, M // 64),
    block=(128, 1, 1),
    arch="sm_90a",
)
def gemm(A, B, C):
    sA = smem.wgmma_tile(bf16, (64, 16), major="K")
    sB = smem.wgmma_tile(bf16, (16, 64), major="MN")
    acc = reg.array(f32, 32)
    # ... TMA loads + ptx.wgmma.mma_async(...) — each call emits exactly one PTX instruction

Every ptx.* call is a single PTX instruction. print(gemm.ptx()) shows exactly what you wrote.
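The grid lambda in the decorator is plain Python evaluated at trace time: one CTA per 64×64 output tile. Reproducing just that function (shapes here are examples):

```python
# Same grid function as in the @kernel decorator above: one CTA per
# 64x64 output tile, N-major in grid.x, M in grid.y.
grid = lambda M, N, K: (N // 64, M // 64)

assert grid(4096, 4096, 4096) == (64, 64)   # 64*64 = 4096 CTAs
assert grid(128, 256, 512) == (4, 2)        # K never affects the grid
```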

One kernel, three runtime paths

The same kernel object works in JAX, PyTorch eager, and torch.compile:

# PyTorch eager
out = gemm(a, b)

# torch.compile
out = torch.compile(gemm)(a, b)

# JAX jit (lowers through typed FFI)
out = jax.jit(gemm)(a, b)

Under the hood the PTX is JITed through cuModuleLoadData, registered with a ~150-line C++ launch shim, and dispatched from PyTorch via torch.library.custom_op or from JAX via jax.ffi.ffi_call.


Transpile existing PTX into pyptx

pyptx is also a real PTX-to-Python transpiler. Feed it output from nvcc, Triton, Pallas, or any other source:

python -m pyptx.codegen kernel.ptx --sugar --name my_kernel > my_kernel.py

--sugar demangles names, raises spin-loops into ptx.loop(...), collapses mbarrier-wait blocks, and groups expression chains. Round-trips are byte-identical on 218+ corpus files (CUTLASS, Triton, fast.cu, DeepGEMM, ThunderKittens, LLVM tests).

The 815 TFLOPS Hopper GEMM in examples/hopper/gemm_highperf_hopper.py is exactly this workflow applied to fast.cu's kernel12.


Start here

Ampere (sm_80):

  • examples/ampere/rms_norm.py / layer_norm.py / swiglu.py / softmax.py — maintained Hopper kernels retargeted to sm_80.
  • examples/ampere/gemm.py — single-warp mma.sync.aligned.m16n8k16 bf16 GEMM, no SMEM staging. The minimal end-to-end Ampere tensor-core path.
  • examples/ampere/gemm_pipelined.py — cp.async 2-stage SMEM ring buffer + mma.sync on a 64×64 CTA tile (per-thread ld.shared, no ldmatrix). The first-step pipelined kernel (~64 TFLOPS at 4096³).
  • examples/ampere/gemm_highperf_ampere.py — production-leaning A100 GEMM following CUTLASS SM80 + MatmulTutorial v15. 128×128×32 CTA tile, 4 warps in 2×2 owning 64×64 each, ldmatrix.x4, 4-stage cp.async pipeline, register frag double-buffering across K-iters, XOR swizzle + serpentine mma, 64 mma.sync per warp per K-iter. 162 TFLOPS at 4096³ bf16 = 73% of cuBLAS (2.5× the simpler gemm_pipelined.py). Bit-exact through 4096³.
  • benchmarks/bench_ampere_kernels.py — A100 RMSNorm, LayerNorm, SwiGLU, and GEMM benchmark suite.

Hopper (sm_90a):

  • examples/hopper/rms_norm.py — simplest real kernel, v4 loads + warp reduce
  • examples/hopper/grouped_gemm.py — multi-k WGMMA for MoE shapes
  • examples/hopper/gemm_highperf_hopper.py — warp-specialized 815 TFLOPS GEMM

Blackwell (sm_100a):

  • examples/blackwell/tcgen05_suite.py — 13 isolated tcgen05 primitives (alloc, MMA, ld, commit/fence, GEMM probes). Run this first on a B200 to verify the runtime stack.
  • examples/blackwell/gemm_highperf_blackwell.py — build_gemm (1SM, 4-stage ring buffer, 1.24 PFLOPS at 8192³ bf16) and build_gemm_2sm (2SM cta_group::2 cooperative MMA, 5-stage).
  • examples/blackwell/gemm_experimental_blackwell.py — persistent and Pallas-style experimental GEMM paths, plus the no-TMA tcgen05 debug GEMM.
  • examples/blackwell/grouped_gemm.py — G-problem MoE grouped GEMM on top of the same tcgen05.mma mainloop, bit-exact against einsum("gmk,gkn->gmn") through G=8 M=1024 N=128 K=1024.
  • examples/blackwell/rms_norm.py / layer_norm.py / swiglu.py — Hopper kernels re-targeted to sm_100a.
  • benchmarks/bench_blackwell_gemm.py — reproduce the 1SM + 2SM + cuBLAS table above.
  • benchmarks/bench_blackwell_kernels.py — Blackwell grouped GEMM, RMSNorm, LayerNorm, and SwiGLU benchmark suite.

Docs: pyptx.dev · API: pyptx.dev/api

Status

0.1.0, pre-launch. Scope:

  • handwritten PTX DSL with full Hopper ISA (wgmma, TMA 2D/3D, mbarriers, cluster)
  • Blackwell tcgen05 ISA (alloc, mma.kind::f16/tf32/f8, ld/st, commit, fence) with instruction-descriptor + SMEM-descriptor helpers
  • PTX parser / emitter with 218+ corpus round-trip tests
  • PTX → Python transpiler with sugar pass
  • JAX runtime integration (typed FFI)
  • PyTorch eager + torch.compile + CUDA graph replay
  • C++ dispatch extension for low-overhead launches
  • GMMA/UMMA SMEM swizzle helpers (B32 / B64 / B128, CuTe-compatible Swizzle<B,4,3>)
  • PyTorch autograd via differentiable_kernel
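The Swizzle<B,4,3> naming follows CuTe's convention: B XOR bits, a 4-bit base (16-byte vectors), shift 3, with B = 1/2/3 giving the 32/64/128-byte patterns. A minimal sketch of the index transform based on CuTe's published definition — this mirrors the convention, it is not pyptx's helper:

```python
def cute_swizzle(offset: int, B: int, M: int = 4, S: int = 3) -> int:
    """CuTe-style Swizzle<B,M,S>: XOR bits [M+S, M+S+B) of `offset`
    into bits [M, M+B)."""
    mask = ((1 << B) - 1) << (M + S)
    return offset ^ ((offset & mask) >> S)

# The transform only reads bits it does not modify, so it is an
# involution (its own inverse), hence a bijection on SMEM offsets.
assert all(cute_swizzle(cute_swizzle(i, 3), 3) == i for i in range(1 << 12))
```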

License

Apache-2.0. See LICENSE.

Download files

Source Distribution

  • pyptx-0.1.1.tar.gz (962.5 kB, source)

Built Distributions

  • pyptx-0.1.1-py3-none-any.whl (271.6 kB, Python 3)
  • pyptx-0.1.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (307.0 kB, CPython 3.13, x86-64)
  • pyptx-0.1.1-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl (304.6 kB, CPython 3.13, ARM64)
  • pyptx-0.1.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (307.0 kB, CPython 3.12, x86-64)
  • pyptx-0.1.1-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl (304.6 kB, CPython 3.12, ARM64)
  • pyptx-0.1.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (307.0 kB, CPython 3.11, x86-64)
  • pyptx-0.1.1-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl (304.6 kB, CPython 3.11, ARM64)
  • pyptx-0.1.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (306.7 kB, CPython 3.10, x86-64)
  • pyptx-0.1.1-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl (304.5 kB, CPython 3.10, ARM64)

File details

Details for the file pyptx-0.1.1.tar.gz.

File metadata

  • Download URL: pyptx-0.1.1.tar.gz
  • Upload date:
  • Size: 962.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for pyptx-0.1.1.tar.gz
Algorithm Hash digest
SHA256 860dc33ac6b01893facaa20797ab8b14767add562748c273e221f75802003dca
MD5 52ea1a40846021d3f8f0448d1f7ce39b
BLAKE2b-256 4751d87a27edc90ee380db8a6c7e893ec38507708abc7d3f3eac10774c9897c1


Provenance

The following attestation bundles were made for pyptx-0.1.1.tar.gz:

Publisher: build-wheels.yml on patrick-toulme/pyptx

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

