
Origami

Analytical GEMM Solution Selection

License: MIT | Python 3.9+

Origami is a fast, analytical, deterministic model for selecting optimal GEMM (General Matrix Multiply) kernel configurations on AMD GPUs. Instead of expensive autotuning, Origami analytically estimates kernel performance by modeling compute and memory latencies to select the best tile size, occupancy, and workgroup mapping for your problem.
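The intuition behind the analytical approach can be sketched with a toy roofline-style estimate. This is a simplification for illustration only: the helper and the hardware numbers below are hypothetical, and Origami's real model additionally accounts for tile sizes, occupancy, and workgroup mapping.

```python
# Toy roofline-style latency bound for a GEMM: the kernel can go no faster
# than the slower of its compute time and its memory time. Illustration only.

def toy_gemm_latency_us(m, n, k, peak_tflops, mem_bw_gbs, bytes_per_elem=2):
    flops = 2 * m * n * k                               # one FMA per M*N*K triple
    traffic = bytes_per_elem * (m * k + k * n + m * n)  # ideal A + B + D traffic
    compute_us = flops / (peak_tflops * 1e12) * 1e6
    memory_us = traffic / (mem_bw_gbs * 1e9) * 1e6
    return max(compute_us, memory_us)

# A 4096^3 FP16 GEMM on a hypothetical 1000 TFLOP/s, 5300 GB/s GPU is
# compute-bound:
print(f"{toy_gemm_latency_us(4096, 4096, 4096, 1000, 5300):.1f} us")  # -> 137.4 us
```

A model like this is deterministic and takes microseconds to evaluate, which is why it can stand in for autotuning runs that take minutes.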

Installation

Requires: ROCm/HIP installed on your system.

pip install rocm-origami

Quick Start

Predict the Best GEMM Configuration

Given a set of candidate tile configurations, Origami selects the one with the lowest predicted latency:

import origami

# Detect your GPU
hardware = origami.get_hardware_for_device(0)

# Define the GEMM problem: D = A^T @ B, where A and B are 4096x4096 FP16 matrices
problem = origami.problem_t()
problem.size = origami.dim3_t(4096, 4096, 4096)  # M, N, K
problem.batch = 1
problem.a_transpose = origami.transpose_t.T
problem.b_transpose = origami.transpose_t.N
problem.a_dtype = origami.data_type_t.Half
problem.b_dtype = origami.data_type_t.Half
problem.c_dtype = origami.data_type_t.Half
problem.d_dtype = origami.data_type_t.Half
problem.mi_dtype = origami.data_type_t.Half

# Define candidate configurations (tile sizes + matrix instruction + occupancy)
configs = []
for mt_m, mt_n, mt_k in [(128, 128, 32), (256, 128, 64), (64, 64, 128)]:
    cfg = origami.config_t()
    cfg.mt = origami.dim3_t(mt_m, mt_n, mt_k)
    cfg.mi = origami.dim3_t(16, 16, 16)
    cfg.occupancy = 2
    configs.append(cfg)

# Select the best configuration
best = origami.select_config(problem, hardware, configs)
print(f"Best tile: {best.config.mt.m}x{best.config.mt.n}x{best.config.mt.k}")
print(f"Predicted latency: {best.latency:.2f}")

Get Latency Predictions for All Configurations

Rank all candidate configurations by predicted performance:

import origami

hardware = origami.get_hardware_for_device(0)

problem = origami.problem_t()
problem.size = origami.dim3_t(8192, 8192, 8192)
problem.batch = 1
problem.a_transpose = origami.transpose_t.T
problem.b_transpose = origami.transpose_t.N
problem.a_dtype = origami.data_type_t.Half
problem.b_dtype = origami.data_type_t.Half
problem.c_dtype = origami.data_type_t.Half
problem.d_dtype = origami.data_type_t.Half
problem.mi_dtype = origami.data_type_t.Half

# Build a config list
configs = []
for mt_m, mt_n, mt_k, occ in [
    (128, 128, 32, 2), (256, 128, 64, 1), (64, 64, 128, 2),
    (256, 256, 64, 1), (128, 256, 32, 2), (64, 128, 64, 2),
]:
    cfg = origami.config_t()
    cfg.mt = origami.dim3_t(mt_m, mt_n, mt_k)
    cfg.mi = origami.dim3_t(16, 16, 16)
    cfg.occupancy = occ
    configs.append(cfg)

# Rank all configs by latency (best first)
ranked = origami.rank_configs(problem, hardware, configs)

for i, result in enumerate(ranked):
    gflops = origami.compute_perf_gflops(hardware, problem, result.latency)
    mt = result.config.mt
    print(f"  #{i+1}: {mt.m:>3}x{mt.n:>3}x{mt.k:>3}  "
          f"latency={result.latency:>10.2f}  {gflops:>8.1f} GFLOPS")

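One reason different tile sizes win for different problem shapes is wave quantization: the grid of output tiles is scheduled across a fixed number of CUs, and a partially filled final wave wastes throughput. A toy illustration of the effect (the 304-CU count and the helper are hypothetical, not part of Origami's API):

```python
import math

def wave_efficiency(m, n, mt_m, mt_n, n_cu):
    """Fraction of scheduled CU-slots doing useful work (toy model)."""
    tiles = math.ceil(m / mt_m) * math.ceil(n / mt_n)  # output tiles in the grid
    waves = math.ceil(tiles / n_cu)                    # rounds of tiles across CUs
    return tiles / (waves * n_cu)                      # 1.0 = perfectly full waves

# Compare two tile shapes for an 8192x8192 output on a hypothetical 304-CU GPU:
for mt_m, mt_n in [(128, 128), (256, 256)]:
    eff = wave_efficiency(8192, 8192, mt_m, mt_n, 304)
    print(f"{mt_m}x{mt_n}: {eff:.2%} of scheduled work is useful")
# -> 128x128: 96.24% ...; 256x256: 84.21% ...
```

Origami's cost model folds effects like this into a single predicted latency per config, which is what rank_configs sorts by.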
Use with Triton (PyTorch)

The OrigamiMatmulSelector replaces Triton's autotuner for GEMM kernels:

import torch
from origami import OrigamiMatmulSelector

# Define Triton-style candidate configs. MockConfig stands in for your own
# config class carrying (BLOCK_M, BLOCK_N, BLOCK_K, waves_per_eu); substitute
# whatever config type your kernel uses.
configs = [
    MockConfig(128, 128, 32, 2),
    MockConfig(256, 128, 64, 1),
    MockConfig(64, 64, 128, 2),
]

selector = OrigamiMatmulSelector(
    config_gen=configs,
    m=4096, n=4096, k=4096,
    a_dtype=torch.float16,
    b_dtype=torch.float16,
    out_dtype=torch.float16,
    device=torch.device("cuda:0"),
)

# Use the selected parameters in your Triton kernel launch
print(f"BLOCK_M={selector.macrotile_m}")
print(f"BLOCK_N={selector.macrotile_n}")
print(f"BLOCK_K={selector.macrotile_k}")
print(f"waves_per_eu={selector.occupancy}")

Supported GPUs

LLVM Target | GPUs                                 | Functional | Optimized
gfx942      | MI325X, MI300X, MI300A               | Yes        | Yes
gfx950      | MI355X, MI350X                       | Yes        | Yes
gfx1100     | Radeon RX 7900 XTX/XT/GRE            | Yes        | -
gfx1150     | AMD Strix Point iGPU                 | Yes        | -
gfx1151     | Radeon RX 8000 series                | Yes        | -
gfx1152     | AMD Radeon 840M iGPU                 | Yes        | -
gfx1153     | AMD Radeon 820M iGPU                 | Yes        | -
gfx1201     | Radeon RX 8900/8800/8700/8600 series | Yes        | -

API Reference

Core Functions

Function                                               | Description
select_config(problem, hardware, configs)              | Select the best configuration for a GEMM problem
rank_configs(problem, hardware, configs)               | Rank all configurations by predicted latency
select_topk_configs(problem, hardware, configs, k)     | Return the top-K configurations
select_config_mnk(M, N, K, hardware, configs)          | Shorthand selection using just M, N, K
compute_total_latency(problem, hardware, config, n_cu) | Compute predicted latency for a single config
compute_perf_gflops(hardware, problem, latency)        | Convert latency to GFLOPS throughput
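For intuition, compute_perf_gflops presumably follows the standard 2·M·N·K FLOP count for GEMM. A hand-rolled equivalent, assuming latency is expressed in microseconds (an assumption for the sketch; the library's latency units may differ):

```python
def gflops_from_latency(m, n, k, latency_us, batch=1):
    flops = 2 * batch * m * n * k      # one multiply + one add per M*N*K triple
    return flops / (latency_us * 1e-6) / 1e9

# A 4096^3 GEMM finishing in 250 us sustains about 550 TFLOPS:
print(f"{gflops_from_latency(4096, 4096, 4096, 250.0):.0f} GFLOPS")  # -> 549756 GFLOPS
```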

Hardware Detection

Function                                          | Description
get_hardware_for_device(device_id)                | Auto-detect GPU hardware from device index
get_hardware_for_arch(arch, n_cu, lds, l2, clock) | Create hardware descriptor for a specific architecture

Types

Type                | Description
problem_t           | GEMM problem specification (dimensions, dtypes, transpose)
config_t            | Kernel configuration (tile size, matrix instruction, occupancy)
prediction_result_t | Result containing a config and its predicted latency
hardware_t          | GPU hardware characteristics
dim3_t              | 3D dimension type (m, n, k)
data_type_t         | Data type enum (Half, BFloat16, Float, Int8, etc.)
transpose_t         | Transpose enum (T, N)

Contributing

File issues and pull requests on GitHub.

Citation

@misc{Swann:2025:TTB,
  title={{tritonBLAS}: Triton-based Analytical Approach for GEMM Kernel Parameter Selection},
  author={Ryan Swann and Muhammad Osama and Xiaohu Guo and Bryant Nelson and Lixun Zhang and Alex Brown and Yen Ong and Ali Yazdani and Sean Siddens and Ganesh Dasika and Alex Underwood},
  year={2025},
  eprint={2512.04226},
  archivePrefix={arXiv},
  primaryClass={cs.DC},
  url={https://arxiv.org/abs/2512.04226},
}

License

MIT
