Optimized FP4 GPU kernels for AMD GPUs

Project description

Petit

Petit provides optimized FP16/BF16 x FP4 GPU kernels designed specifically for AMD GPUs. It enables efficient execution of NVFP4 and MXFP4 quantized models on GPUs that lack native FP4 arithmetic. This makes Petit well suited for serving high-quality FP4 models on standard GPUs while achieving ~3.3x memory savings. For example, a server with 8x AMD MI300X GPUs running SGLang v0.4.9.post2 can serve Llama-3.3-70B-Instruct and Llama-3.3-70B-Instruct-FP4, which reach MMLU scores of 82.15 and 80.79, respectively.
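The memory savings follow from the storage formats. As a back-of-the-envelope sketch (our own arithmetic, not a figure published by Petit), each FP4 (E2M1) weight costs 4 bits plus its share of a per-block scale: NVFP4 uses 16-element blocks with an FP8 (E4M3) scale, and MXFP4 uses 32-element blocks with an 8-bit E8M0 scale.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BlockFP4Format:
    name: str
    block_size: int  # number of FP4 (E2M1) elements sharing one scale
    scale_bits: int  # storage size of each per-block scale

    def bits_per_weight(self) -> float:
        # 4 bits for the E2M1 element plus its share of the block scale
        return 4 + self.scale_bits / self.block_size

    def compression_vs(self, baseline_bits: int = 16) -> float:
        # How many times smaller than a 16-bit (FP16/BF16) weight
        return baseline_bits / self.bits_per_weight()


# NVFP4: 16-element blocks with an FP8 (E4M3) scale. Its additional
# per-tensor FP32 scale is ignored because it is negligible for large matrices.
NVFP4 = BlockFP4Format("NVFP4", block_size=16, scale_bits=8)
# MXFP4 (OCP Microscaling): 32-element blocks with an 8-bit E8M0 scale.
MXFP4 = BlockFP4Format("MXFP4", block_size=32, scale_bits=8)

for fmt in (NVFP4, MXFP4):
    print(f"{fmt.name}: {fmt.bits_per_weight():.2f} bits/weight, "
          f"{fmt.compression_vs(16):.2f}x smaller than BF16/FP16")
```

For the quantized matrices alone this gives roughly 3.6x (NVFP4) and 3.8x (MXFP4). The end-to-end ~3.3x figure quoted above is lower, which would be consistent with some tensors (for example embeddings) staying in 16-bit precision; that attribution is our assumption.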

Requirements

  • AMD CDNA2 / CDNA3 GPUs (AMD MI2xx / MI3xx series)
  • ROCm 6.2 or later
  • PyTorch 2.5 or later
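The requirements above can be checked programmatically before building. A minimal sketch, assuming that CDNA2 corresponds to the gfx90a target and CDNA3 to gfx940/gfx941/gfx942 (our mapping, not stated by Petit), and that the HIP version reported by PyTorch tracks the ROCm release:

```python
import re
from typing import List, Optional, Tuple

# Assumed mapping from gfx targets to the architectures named above.
SUPPORTED_GFX = {"gfx90a": "CDNA2", "gfx940": "CDNA3",
                 "gfx941": "CDNA3", "gfx942": "CDNA3"}


def major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings like '2.5.1+rocm6.2' or '6.2.41133-abc'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"cannot parse version {version!r}")
    return int(match.group(1)), int(match.group(2))


def requirement_problems(gcn_arch_name: str, hip_version: Optional[str],
                         torch_version: str) -> List[str]:
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    base_arch = gcn_arch_name.split(":")[0]  # "gfx942:sramecc+:xnack-" -> "gfx942"
    if base_arch not in SUPPORTED_GFX:
        problems.append(f"unsupported GPU architecture {base_arch}")
    if hip_version is None:
        problems.append("PyTorch was not built for ROCm")
    elif major_minor(hip_version) < (6, 2):
        problems.append(f"ROCm/HIP {hip_version} is older than 6.2")
    if major_minor(torch_version) < (2, 5):
        problems.append(f"PyTorch {torch_version} is older than 2.5")
    return problems
```

On a ROCm build of PyTorch the three inputs can come from torch.cuda.get_device_properties(0).gcnArchName, torch.version.hip and torch.__version__; torch.version.hip is None on builds without ROCm.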

Installation and Usage

You can build and install Petit from a source checkout with pip:

$ CMAKE_ARGS='-DCMAKE_PREFIX_PATH=/opt/rocm;/usr/local/lib/python3.12/dist-packages/torch' pip install .

You need to set CMAKE_PREFIX_PATH in CMAKE_ARGS so that CMake can locate ROCm and PyTorch.
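The PyTorch path in the install command depends on where Python and PyTorch are installed (the example uses a system-wide Python 3.12). A small sketch that derives it instead of hard-coding it; reading ROCM_PATH with a default of /opt/rocm is our assumption about a typical ROCm setup. Note that CMake separates list entries with ';', which is why the value must be quoted in the shell.

```python
import importlib.util
import os
import shlex


def torch_package_dir() -> str:
    """Locate the installed torch package directory without importing torch."""
    spec = importlib.util.find_spec("torch")
    if spec is None or not spec.submodule_search_locations:
        raise RuntimeError("PyTorch is not installed for this interpreter")
    return list(spec.submodule_search_locations)[0]


def default_rocm_dir() -> str:
    """Assumed convention: honor ROCM_PATH if set, else the usual /opt/rocm."""
    return os.environ.get("ROCM_PATH", "/opt/rocm")


def cmake_prefix_arg(rocm_dir: str, torch_dir: str) -> str:
    """CMAKE_PREFIX_PATH is a semicolon-separated list of search prefixes."""
    return f"-DCMAKE_PREFIX_PATH={rocm_dir};{torch_dir}"


def install_command(rocm_dir: str, torch_dir: str) -> str:
    """Build the pip command, quoting the value so the shell does not
    treat ';' as a command separator."""
    return f"CMAKE_ARGS={shlex.quote(cmake_prefix_arg(rocm_dir, torch_dir))} pip install ."
```

For example, print(install_command(default_rocm_dir(), torch_package_dir())) prints a command of the same shape as the one above for the current interpreter.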

Petit provides Python APIs for matrix multiplication that are intended to be integrated into inference frameworks such as SGLang and vLLM. It also provides C++ bindings to enable integration with frameworks like llama.cpp.
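To make the semantics of an FP16/BF16 x FP4 matrix multiplication concrete, here is a slow pure-Python reference that dequantizes FP4 weights and computes y = x @ W^T. It is not Petit's API: the function names, the low-nibble-first packing, the plain row-major (unshuffled) layout and the plain-float block scales are all assumptions for illustration. Real NVFP4 scales are FP8 (E4M3) values combined with a per-tensor FP32 scale, and MXFP4 uses power-of-two E8M0 scales.

```python
from typing import List, Sequence

# Magnitudes of the eight E2M1 codes 0..7; bit 3 (0x8) is the sign bit.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code into a Python float."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def dequantize_row(packed: bytes, scales: Sequence[float], block_size: int) -> List[float]:
    """Unpack two FP4 codes per byte (low nibble first, an assumption)
    and multiply each element by the scale of its block."""
    codes = []
    for byte in packed:
        codes.append(byte & 0x0F)
        codes.append(byte >> 4)
    return [decode_e2m1(c) * scales[i // block_size] for i, c in enumerate(codes)]


def reference_fp4_matmul(x: Sequence[Sequence[float]],
                         packed_w: Sequence[bytes],
                         w_scales: Sequence[Sequence[float]],
                         block_size: int) -> List[List[float]]:
    """y[m][n] = sum_k x[m][k] * dequant(W)[n][k], i.e. y = x @ W^T."""
    w = [dequantize_row(row, s, block_size) for row, s in zip(packed_w, w_scales)]
    return [[sum(a * b for a, b in zip(x_row, w_row)) for w_row in w] for x_row in x]
```

A reference like this can be useful for checking an integration on tiny shapes, since its output should match an optimized kernel up to rounding.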

Techniques and Evaluations

Similar to Marlin, Petit shuffles weights offline so that the GPU can dequantize them efficiently. To achieve optimal performance, Petit uses ranged buffer loads and vector instructions specific to the CDNA2 and CDNA3 architectures. These optimizations assume that scales are positive and that the quantized weights contain no negative zeros. For details about these optimizations, please refer to the project documentation.
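Because the kernels rely on positive scales and on the absence of negative zeros, a pre-flight check on a checkpoint can catch violations before the weights are shuffled. The helpers below are hypothetical (not part of Petit) and assume E2M1 codes packed two per byte. In E2M1 the code 0b1000 is negative zero, which is numerically equal to +0, so rewriting it as 0b0000 leaves the dequantized values unchanged.

```python
from typing import Iterable

NEGATIVE_ZERO = 0x8  # E2M1 code with the sign bit set and zero magnitude


def check_assumptions(packed: bytes, scales: Iterable[float]) -> None:
    """Raise ValueError if any scale is not positive or any FP4 code is -0."""
    for i, scale in enumerate(scales):
        if not scale > 0:  # also rejects NaN, since NaN > 0 is False
            raise ValueError(f"scale {i} is not positive: {scale!r}")
    for i, byte in enumerate(packed):
        if (byte & 0x0F) == NEGATIVE_ZERO or (byte >> 4) == NEGATIVE_ZERO:
            raise ValueError(f"byte {i} (0x{byte:02x}) holds a negative-zero code")


def canonicalize_negative_zeros(packed: bytes) -> bytes:
    """Rewrite every -0 nibble (0x8) as +0 (0x0) in both halves of each byte."""
    out = bytearray()
    for byte in packed:
        low, high = byte & 0x0F, byte >> 4
        low = 0 if low == NEGATIVE_ZERO else low
        high = 0 if high == NEGATIVE_ZERO else high
        out.append((high << 4) | low)
    return bytes(out)
```

A negative scale, by contrast, cannot be fixed this way without also changing the quantized codes, so the check simply reports it.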

Petit is optimized for the common real-world case in which LLM engines run inference with small batches. For example, with batch sizes below 16, Petit performs BF16 matrix multiplications 1.2x-2.2x faster than hipBLASLt. For larger batches, where performance is bound by compute throughput, Petit performs within 70% of the hand-optimized hipBLASLt library.
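Why batch size matters can be illustrated with a simple roofline-style estimate (our illustration, not Petit's benchmark methodology; it ignores caches, kernel overheads and scale traffic beyond the 4.5-bit NVFP4 figure). With few rows, nearly all memory traffic is the weight matrix, so 4-bit weights raise the FLOPs performed per byte read; once the intensity exceeds the GPU's ratio of peak FLOPs to memory bandwidth (taken from the vendor datasheet), the kernel becomes compute-bound.

```python
def gemm_intensity(m: int, n: int, k: int, weight_bits: float, act_bytes: int = 2) -> float:
    """FLOPs per byte of memory traffic for y[m, n] = x[m, k] @ W[n, k]^T,
    assuming x and W are each read once and y is written once."""
    flops = 2 * m * n * k                            # one multiply and one add per term
    weight_bytes = n * k * weight_bits / 8           # quantized weight matrix
    activation_bytes = (m * k + m * n) * act_bytes   # x read plus y written
    return flops / (weight_bytes + activation_bytes)


def is_memory_bound(intensity: float, peak_flops: float, bandwidth_bytes: float) -> bool:
    """Roofline test: below the ridge point peak_flops / bandwidth, memory limits speed."""
    return intensity < peak_flops / bandwidth_bytes


# Illustrative 8192 x 8192 layer: FP4 weights (4.5 bits incl. NVFP4 scales) vs BF16.
for m in (1, 16, 256):
    fp4 = gemm_intensity(m, 8192, 8192, weight_bits=4.5)
    bf16 = gemm_intensity(m, 8192, 8192, weight_bits=16)
    print(f"batch {m:>3}: {fp4:8.1f} FLOP/B with FP4 weights, {bf16:8.1f} FLOP/B with BF16 weights")
```

At batch 1 the FP4 intensity is about 3.5 FLOP/B versus about 1 FLOP/B for BF16 weights, which is one plausible reason a small-batch FP4 kernel can outrun a BF16 GEMM even though it must dequantize.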

Known Issues

Similar to Marlin, Petit shuffles the data offline to minimize the work performed on the GPU. It requires all scales to be positive, which matches the output of the ModelOpt quantizer.

The MFMA instructions on AMD MI2xx GPUs flush denormal input and output values to zero, which can hurt numerical accuracy. Petit implements corrective measures for MI2xx GPUs at a cost of roughly 10% overhead.
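To see what flushing denormals means numerically, the sketch below rounds values to IEEE binary16 with the struct module's 'e' format and then emulates flush-to-zero. It is our own illustration, not Petit's correction scheme; FP16 is chosen only because its normal range is narrow (smallest normal value 2^-14), so small products are easily affected.

```python
import math
import struct

FP16_MIN_NORMAL = 2.0 ** -14  # smallest positive normal IEEE binary16 value


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest binary16 value (subnormals kept)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def flush_to_zero_fp16(x: float) -> float:
    """Round to binary16, then flush a subnormal result to a signed zero,
    emulating hardware that does not keep FP16 denormals."""
    h = to_fp16(x)
    if h != 0.0 and abs(h) < FP16_MIN_NORMAL:
        return math.copysign(0.0, h)
    return h


# A small activation times a small dequantized weight lands in the
# subnormal range: binary16 can represent it, but FTZ hardware drops it.
activation, weight = 2.0 ** -8, 2.0 ** -9
product = activation * weight  # 2**-17, subnormal in binary16
print(to_fp16(product), flush_to_zero_fp16(product))
```

Many such dropped contributions in a long dot product can bias the result, which is why a correction step may be worth its overhead.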

Compared to NVIDIA architectures, CDNA architectures are significantly more sensitive to kernel hyperparameters such as the shapes of the tiles staged in shared memory. We strongly recommend running auto-tuning to achieve optimal performance; the repository provides a benchmarking tool to facilitate it.
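Auto-tuning of this kind generally amounts to timing every candidate configuration and keeping the fastest. The sketch below shows that loop in a framework-agnostic way; TileConfig, its fields and the search space are hypothetical and do not describe Petit's actual hyperparameters or its benchmarking tool.

```python
import itertools
import statistics
import time
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, Tuple


@dataclass(frozen=True)
class TileConfig:
    """Hypothetical tuning knobs; Petit's real hyperparameters may differ."""
    tile_m: int
    tile_n: int
    tile_k: int


def median_runtime(fn: Callable[[], None], warmup: int = 3, repeats: int = 10) -> float:
    """Median wall-clock seconds of fn() after a few warm-up calls.
    For GPU work, fn must synchronize before returning (for example with
    torch.cuda.synchronize()), otherwise only the launch is timed."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def autotune(configs: Iterable[TileConfig],
             bench: Callable[[TileConfig], float]) -> Tuple[TileConfig, Dict[TileConfig, float]]:
    """Benchmark every candidate once and return the fastest with all timings."""
    timings = {cfg: bench(cfg) for cfg in configs}
    best = min(timings, key=timings.__getitem__)
    return best, timings


# Exhaustive grid over a small, made-up search space.
CANDIDATES = [TileConfig(m, n, k)
              for m, n, k in itertools.product((16, 32), (64, 128), (64, 128))]
```

In practice bench would wrap median_runtime around a launch of the kernel built with the given configuration, and the winner would be cached per GPU model and matrix shape, since the best choice usually varies with both.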

Contacts and Contributions

We thank AMD and InnoMatrix for generously providing access to the GPUs that made this project possible. Neither organization is involved in the development of the project.

Petit is a very young project, and we are still implementing various optimizations. Please contact haohui@causalflow.ai with questions or for support. Contributions are welcome.

Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

petit_kernel-0.0.4-cp312-cp312-manylinux_2_39_x86_64.whl (6.3 MB)

CPython 3.12, manylinux (glibc 2.39+), x86-64

petit_kernel-0.0.4-cp310-cp310-manylinux_2_34_x86_64.whl (6.3 MB)

CPython 3.10, manylinux (glibc 2.34+), x86-64

File details

Details for the file petit_kernel-0.0.4-cp312-cp312-manylinux_2_39_x86_64.whl.

File hashes

Hashes for petit_kernel-0.0.4-cp312-cp312-manylinux_2_39_x86_64.whl
Algorithm Hash digest
SHA256 008d28cb09cd2b6ed71475628fe4253cf9c2bb4846092d0b52960c8f7109135b
MD5 df5e61a72dcfcd1f1f88f257c7fe4d3c
BLAKE2b-256 294fca19afd047e83d5b4dbbd6e5bc978d1801199f2032170f8faa0dc5811be0

File details

Details for the file petit_kernel-0.0.4-cp310-cp310-manylinux_2_34_x86_64.whl.

File hashes

Hashes for petit_kernel-0.0.4-cp310-cp310-manylinux_2_34_x86_64.whl
Algorithm Hash digest
SHA256 002202aef2b7a05b5026e934da002fc5faa508b2c0c3e5f4e8477293a148f8e9
MD5 eb05983aa7a50bb8b4c81f6c07e5c3f6
BLAKE2b-256 fee6cb789a585cb8aceb00f42eb3ee1c68cacf7370384c25bb53c2fbdbc14714
