Optimized FP4 GPU kernels for AMD GPUs

Petit

Petit provides optimized FP16/BF16 x FP4 GPU kernels specifically designed for AMD GPUs. It enables efficient execution of NVFP4 and MXFP4 quantized models on GPUs that lack native FP4 arithmetic. This makes Petit well suited for serving high-quality FP4 models on standard GPUs while achieving ~3.3x memory savings. For example, a server with 8x AMD MI300X GPUs running sglang v0.4.9.post2 can serve Llama-3.3-70B-Instruct and Llama-3.3-70B-Instruct-FP4 with MMLU scores of 82.15 and 80.79, respectively.
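
As a rough sanity check on the memory savings, the sketch below computes the storage cost per weight for the two block-scaled formats, assuming the published layouts: NVFP4 stores one 8-bit scale per 16 elements and MXFP4 stores one 8-bit scale per 32 elements. The per-tensor scale that NVFP4 also carries is ignored as negligible. These ratios cover the quantized tensors only; the lower end-to-end figure quoted above presumably reflects tensors that stay in higher precision, which is a guess rather than something the project states.

```python
BF16_BITS = 16  # bits per weight in the unquantized baseline


def bits_per_weight(element_bits: int, scale_bits: int, block_size: int) -> float:
    """Storage per weight once the shared block scale is amortized."""
    return element_bits + scale_bits / block_size


def compression_ratio(quantized_bits: float) -> float:
    """How many times smaller a quantized tensor is than its BF16 form."""
    return BF16_BITS / quantized_bits


nvfp4_bits = bits_per_weight(element_bits=4, scale_bits=8, block_size=16)
mxfp4_bits = bits_per_weight(element_bits=4, scale_bits=8, block_size=32)

print(f"NVFP4: {nvfp4_bits} bits/weight, {compression_ratio(nvfp4_bits):.2f}x smaller")
print(f"MXFP4: {mxfp4_bits} bits/weight, {compression_ratio(mxfp4_bits):.2f}x smaller")
```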

Requirements

  • AMD CDNA2 / CDNA3 GPUs (AMD MI2xx / MI3xx series)
  • ROCm 6.2 or later
  • PyTorch 2.5 or later
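
The version requirements above can be checked before building. The helper below is a minimal sketch, not part of Petit: `parse_version` and `meets_minimum` are hypothetical names, and the code only compares version strings, so it runs without PyTorch installed.

```python
import re


def parse_version(text: str) -> tuple:
    """Extract the leading numeric components, e.g. '2.5.1+rocm6.2' -> (2, 5, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"no version number found in {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(version: str, minimum: tuple) -> bool:
    """Return True if `version` is at least `minimum`, compared component-wise."""
    parsed = parse_version(version)
    # Pad short versions such as '6' so that they compare as (6, 0).
    padded = parsed + (0,) * (len(minimum) - len(parsed))
    return padded[: len(minimum)] >= minimum


# On a ROCm build of PyTorch, the inputs could come from:
#   import torch
#   meets_minimum(torch.__version__, (2, 5))
#   torch.version.hip is not None and meets_minimum(torch.version.hip, (6, 2))
```

`torch.version.hip` is `None` on builds without ROCm support, so the `is not None` check doubles as a test for a ROCm-enabled PyTorch.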

Installation and Usage

You can install Petit directly using pip:

$ CMAKE_ARGS='-DCMAKE_PREFIX_PATH=/opt/rocm;/usr/local/lib/python3.12/dist-packages/torch' pip install .

You need to specify CMAKE_PREFIX_PATH in CMAKE_ARGS so that CMake can locate both ROCm and PyTorch.
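
The PyTorch path in the command above depends on the Python version and on where packages are installed. The sketch below assembles the same command from explicit paths; `cmake_args` and `install_command` are hypothetical helper names, and `/opt/rocm` is simply the location used in the example above.

```python
import shlex


def cmake_args(prefix_paths) -> str:
    """Join prefix paths with ';', the list separator CMake expects."""
    return "-DCMAKE_PREFIX_PATH=" + ";".join(prefix_paths)


def install_command(prefix_paths) -> str:
    """Build the shell command, quoting the value so ';' is not treated as a separator."""
    return f"CMAKE_ARGS={shlex.quote(cmake_args(prefix_paths))} pip install ."


print(install_command(["/opt/rocm", "/usr/local/lib/python3.12/dist-packages/torch"]))

# To use the PyTorch that is actually installed, the second path could be
# obtained with:
#   import os, torch
#   os.path.dirname(torch.__file__)
```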

Petit provides Python APIs for matrix multiplication that are intended to be integrated with inference frameworks such as SGLang and vLLM. It also provides C++ bindings to enable integration with frameworks like llama.cpp.

Techniques and Evaluations

Similar to Marlin, Petit performs offline weight shuffling to enable efficient GPU dequantization. To achieve optimal performance, Petit uses ranged buffer loads and vector instructions specifically designed for the CDNA2 and CDNA3 architectures. These optimizations rely on two assumptions: scales are positive, and quantized weights contain no negative zeros. For detailed information about these optimizations, please refer to the documentation available here.
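
To make the two assumptions concrete, here is a plain reference decoder for block-scaled FP4 (E2M1) data. It is a sketch for illustration only, not Petit's shuffled layout or its GPU kernel. The nibble order (low nibble first) and the function names are assumptions; the two checks mirror the requirements on scales and negative zeros.

```python
# E2M1 magnitudes, indexed by the low three bits of a 4-bit code.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
NEGATIVE_ZERO_CODE = 0x8  # sign bit set, magnitude bits all zero


def decode_fp4(code: int) -> float:
    """Decode one 4-bit E2M1 code: bit 3 is the sign, bits 0-2 select the magnitude."""
    sign = -1.0 if code & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[code & 0x7]


def dequantize_block(packed: bytes, scale: float) -> list:
    """Unpack two FP4 codes per byte and multiply each by the block scale."""
    if scale <= 0:
        raise ValueError("scale must be positive")
    values = []
    for byte in packed:
        for code in (byte & 0xF, byte >> 4):  # assumption: low nibble comes first
            if code == NEGATIVE_ZERO_CODE:
                raise ValueError("negative zero is not expected in the weights")
            values.append(decode_fp4(code) * scale)
    return values


print(dequantize_block(bytes([0x21, 0xF7]), scale=2.0))
```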

Petit is optimized for real-world use cases in which LLM engines perform inference with small batches. For example, Petit is 1.2x-2.2x faster than hipBLASLt for BF16 matrix multiplications when the batch size is less than 16. For larger batches, where performance is bound by the available compute throughput, Petit performs within 70% of the hand-optimized hipBLASLt library.

Known Issues

Similar to Marlin, Petit shuffles the data offline to minimize the work performed on the GPU side. It requires that all scales be positive, which matches the output of the ModelOpt quantizer.

The MFMA instructions on AMD MI2xx GPUs flush input and output denormal values to zero, which can affect numerical accuracy. Petit implements corrective measures for AMD MI2xx GPUs, at a cost of ~10% overhead.
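
The effect of flushing can be illustrated with a small emulation. The sketch below uses the FP16 range as an example and treats values as ordinary Python floats without rounding them to FP16; `flush_to_zero` is a hypothetical helper, and the code does not reproduce the corrective measures that Petit applies.

```python
import math

FP16_MIN_NORMAL = 2.0 ** -14  # smallest positive normal FP16 value, about 6.1e-5


def flush_to_zero(x: float, min_normal: float = FP16_MIN_NORMAL) -> float:
    """Emulate flush-to-zero: nonzero values below the normal range become signed zero."""
    if 0.0 < abs(x) < min_normal:
        return math.copysign(0.0, x)
    return x


# A dequantized weight is an FP4 magnitude times a block scale. With a small
# scale, the smallest nonzero magnitude (0.5) falls into the denormal range
# and is lost, while a larger magnitude with the same scale survives.
scale = 1e-4
print(flush_to_zero(0.5 * scale))  # below FP16_MIN_NORMAL, so flushed
print(flush_to_zero(6.0 * scale))  # within the normal range, so unchanged
```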

Compared to NVIDIA architectures, CDNA architectures are significantly more sensitive to kernel hyperparameters such as the tile shapes in shared memory. We strongly recommend running auto-tuning to achieve optimal performance. The repository provides a benchmarking tool to facilitate auto-tuning.
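
In its simplest form, auto-tuning is a search over candidate configurations that keeps the fastest one. The sketch below shows that loop in generic form. It is not the repository's benchmarking tool: `autotune`, the `measure` callable, and the `tile_m` / `tile_n` keys are all hypothetical, and the stand-in cost function exists only so that the example is deterministic.

```python
def autotune(candidates, measure, repeats: int = 5):
    """Return the (config, time) pair with the lowest best-of-`repeats` time."""
    best_config, best_time = None, float("inf")
    for config in candidates:
        # Taking the minimum over several runs reduces the effect of timing noise.
        elapsed = min(measure(config) for _ in range(repeats))
        if elapsed < best_time:
            best_config, best_time = config, elapsed
    return best_config, best_time


# Hypothetical search space over two tile dimensions.
candidates = [
    {"tile_m": m, "tile_n": n}
    for m in (16, 32, 64)
    for n in (64, 128)
]


def fake_measure(config) -> float:
    """Stand-in cost model; a real measurement would time a kernel launch."""
    return abs(config["tile_m"] - 32) + abs(config["tile_n"] - 128) / 64 + 1.0


best_config, best_time = autotune(candidates, fake_measure)
print(best_config, best_time)
```

When timing real GPU kernels, the measurement needs to wait for the device to finish (for example with `torch.cuda.synchronize()`, which also applies to ROCm builds of PyTorch) before the clock is read; otherwise only the launch overhead is measured.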

Contacts and Contributions

We thank AMD and InnoMatrix for generously providing access to the GPUs that made this project possible. Neither organization is involved in the development of the project.

Petit is a very young project, and we are still working on implementing various optimizations. Please contact haohui@causalflow.ai for questions and support. Contributions are welcome.
