
TorchMX: PyTorch Quantization Framework For OCP MX Datatypes


This package is a simulation tool implementing the MX quantization format specified in the OCP Microscaling (MX) Formats specification. The package includes:

  • A tensor subclass, MXTensor, for representing MX-quantized data.
  • Quantization and dequantization functions for converting between high-precision and MX-quantized tensors.
  • Support for various MX data types, including FP8, FP6, FP4, and INT8.
  • Custom ATen operations for MXTensor.
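To build intuition for what block-wise MX quantization does, here is a conceptual, pure-Python sketch: each block of values shares a single power-of-two scale, and the elements are stored in a low-precision format. This is an illustration of the general idea only, not the library's implementation; real MX element types are narrow floating-point formats such as fp8_e4m3, while this sketch uses int8 elements for simplicity.

```python
import math

def quantize_block_int8(block):
    """Illustrative MX-style block quantization (NOT the torchmx code):
    the whole block shares one power-of-two scale; elements become int8."""
    amax = max(abs(v) for v in block)
    # Choose a power-of-two scale so the largest magnitude lands
    # comfortably inside the int8 range [-128, 127].
    scale = 2.0 ** (math.floor(math.log2(amax)) - 6) if amax > 0 else 1.0
    q = [max(-128, min(127, round(v / scale))) for v in block]
    return q, scale

def dequantize_block(q, scale):
    # Dequantization is a single multiply by the shared scale.
    return [v * scale for v in q]

block = [0.5, -1.25, 3.0, 0.01]
q, scale = quantize_block_int8(block)
approx = dequantize_block(q, scale)
```

Values that are exactly representable at the shared scale round-trip losslessly; very small values in the same block (like 0.01 above) may round to zero, which is the characteristic trade-off of a shared block scale.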

Installation

pip install torchmx

Usage

Here's a basic example of how to quantize a PyTorch tensor to MX format:

import torch
from torchmx import MXTensor, dtypes

# Create a high-precision tensor
x_hp = torch.randn(128, 128, dtype=torch.bfloat16)

# Quantize the tensor to MX format
x_mx = MXTensor.to_mx(x_hp, elem_dtype=dtypes.float8_e4m3, block_size=32)

# Dequantize the tensor back to high-precision
x_hp_reconstructed = x_mx.to_dtype(torch.bfloat16)

# Matmul 2 MXTensors
y_hp = torch.randn(128, 128, dtype=torch.bfloat16)
y_mx = MXTensor.to_mx(y_hp, elem_dtype=dtypes.float6_e3m2, block_size=32)

# Notice the magic here: you can pass MXTensors directly into torch.matmul.
# This even works for 4D attention matmuls, e.g. torch.matmul(Q, K.transpose(-2, -1)).
# The output is a bf16 torch.Tensor.
out_bf16 = torch.matmul(x_mx, y_mx)

Quantizing Layers and Modules

TorchMX also provides tools for quantizing individual layers and modules. The following example demonstrates how to quantize all torch.nn.Linear layers in a model to MX format using quantize_linear_, which swaps them for MXInferenceLinear modules. In this case the weights are quantized to MX-fp6_e3m2 and the activations to MX-fp8_e4m3.

import torch
from torch import nn

from torchmx.config import MXConfig, QLinearConfig
from torchmx.quant_api import quantize_linear_

# Create a high-precision model
model = nn.Sequential(
          nn.Linear(128, 256),
          nn.ReLU(),
          nn.Linear(256, 64),
          nn.ReLU()
        ).to(torch.bfloat16)

# Define the quantization configuration
qconfig = QLinearConfig(
    weights_config=MXConfig(elem_dtype_name="float6_e3m2", block_size=32),
    activations_config=MXConfig(elem_dtype_name="float8_e4m3", block_size=32),
)

# Quantize the model to MX format. Note: this modifies the model in place.
quantize_linear_(model=model, qconfig=qconfig)


# Perform inference using the quantized model
x_hp = torch.randn(16, 128, dtype=torch.bfloat16)
y_mx = model(x_hp)

Examples

For more detailed examples, refer to the examples directory of the repository.

Testing

Run the test suite with:

pytest

License

torchmx is released under the MIT license.

Citation

If you find the torchmx library useful, please cite it in your work as follows.

@software{torchmx,
  title = {torchmx: PyTorch quantization framework for OCP MX datatypes},
  author = {Abhijit Balaji and Marios Fournarakis and {TorchMX maintainers and contributors}},
  url = {https://github.com/rain-neuromorphics/torchmx},
  license = {MIT},
  month = may,
  year = {2025}
}
