
Quantization utility modules that bridge torch FX and PT2E quantized models, as well as ONNX and other formats. Inspired by methods in mmdeploy, but without its outdated dependencies, and with some features not found there.

Project description

quantizeutils

Quantization utility modules I used in my About Quantization guide.

Installation

# @ shell

pip install quantizeutils

# or

poetry add quantizeutils

Usage

Pre- and post-processing FX-traced models around QAT

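  • quantizeutils.fx.utils.pre_process.propagate_split_share_qparams_pre_process()
    • After FX tracing and prepare(), the outputs of torch.split() get their own observers instead of sharing the quantization parameters of the split input (in the example below, the squeeze outputs after the pool behave the same way). This function makes those outputs reuse the input's observer, so the quantization parameters are shared as intended.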
import torch
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.observer import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver
from torch.ao.quantization.fx.tracer import QuantizationTracer
from torch.fx import GraphModule
from torch.ao.quantization.fx import prepare
from torch.ao.quantization.backend_config import get_native_backend_config

from quantizeutils.fx.utils.pre_process import propagate_split_share_qparams_pre_process

class ExampleModel(torch.nn.Module):
    '''
    Expects mnist input of shape (batch,1,28,28)
    '''
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1,30,1,1)
        self.spconv1 = torch.nn.Conv2d(15,5,1,1)
        self.spconv2 = torch.nn.Conv2d(15,5,1,1)
        # self.pool = torch.nn.AdaptiveMaxPool2d((1,1))
        # not supported once quantized, so replace with manual MaxPool2d
        self.pool = torch.nn.MaxPool2d(28,28)
        self.fc = torch.nn.Linear(10, 10)
    def forward(self,x):
        x = self.conv(x)
        y, z = torch.split(x, 15, dim=1)  # split the 30 channels into two halves of 15
        y = self.spconv1(y)
        z = self.spconv2(z)
        x = torch.cat([y,z], dim=1)
        x = self.pool(x)
        x = torch.squeeze(x, dim=2)
        x = torch.squeeze(x, dim=2)
        x = self.fc(x)
        return x


model = ExampleModel()

# define the qconfig_mapping
qconfig_mapping = QConfigMapping().set_global(
        QConfig(
            activation=MovingAverageMinMaxObserver.with_args(
                dtype=torch.quint8,
                qscheme=torch.per_tensor_affine,
                ),
            weight=MovingAveragePerChannelMinMaxObserver.with_args(
                dtype=torch.qint8,
                qscheme=torch.per_channel_symmetric,
                ),
            )
        )

# FX trace the model
tracer = QuantizationTracer(skipped_module_names=[], skipped_module_classes=[])
graph = tracer.trace(model)
traced_fx = GraphModule(tracer.root, graph, 'ExampleModel')
print(traced_fx)
'''
ExampleModel(
  (conv): Conv2d(1, 30, kernel_size=(1, 1), stride=(1, 1))
  (spconv1): Conv2d(15, 5, kernel_size=(1, 1), stride=(1, 1))
  (spconv2): Conv2d(15, 5, kernel_size=(1, 1), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=28, stride=28, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=10, out_features=10, bias=True)
)



def forward(self, x):
    conv = self.conv(x);  x = None
    split = torch.functional.split(conv, 15, dim = 1);  conv = None
    getitem = split[0]
    getitem_1 = split[1];  split = None
    spconv1 = self.spconv1(getitem);  getitem = None
    spconv2 = self.spconv2(getitem_1);  getitem_1 = None
    cat = torch.cat([spconv1, spconv2], dim = 1);  spconv1 = spconv2 = None
    pool = self.pool(cat);  cat = None
    squeeze = torch.squeeze(pool, dim = 2);  pool = None
    squeeze_1 = torch.squeeze(squeeze, dim = 2);  squeeze = None
    fc = self.fc(squeeze_1);  squeeze_1 = None
    return fc

# To see more debug info, please use `graph_module.print_readable()`
'''

# FX prepare the quantization nodes
example_inputs = (torch.randn(1,1,28,28),)
backend_config = get_native_backend_config()
prepared_fx = prepare(
    traced_fx,
    qconfig_mapping=qconfig_mapping,
    node_name_to_scope=tracer.node_name_to_scope,
    is_qat=True, # convenient even if not QAT
    example_inputs=example_inputs,
    backend_config=backend_config,
    )
print(prepared_fx)
'''
GraphModule(
  (activation_post_process_0): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (conv): Conv2d(
    1, 30, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_1): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_2): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_4): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (spconv1): Conv2d(
    15, 5, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_3): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (spconv2): Conv2d(
    15, 5, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_5): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_6): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (pool): MaxPool2d(kernel_size=28, stride=28, padding=0, dilation=1, ceil_mode=False)
  (activation_post_process_7): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_8): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_9): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (fc): Linear(
    in_features=10, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_10): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
)



def forward(self, x):
    activation_post_process_0 = self.activation_post_process_0(x);  x = None
    conv = self.conv(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_1 = self.activation_post_process_1(conv);  conv = None
    split = torch.functional.split(activation_post_process_1, 15, dim = 1);  activation_post_process_1 = None
    getitem = split[0]
    activation_post_process_2 = self.activation_post_process_2(getitem);  getitem = None
    getitem_1 = split[1];  split = None
    activation_post_process_4 = self.activation_post_process_4(getitem_1);  getitem_1 = None
    spconv1 = self.spconv1(activation_post_process_2);  activation_post_process_2 = None
    activation_post_process_3 = self.activation_post_process_3(spconv1);  spconv1 = None
    spconv2 = self.spconv2(activation_post_process_4);  activation_post_process_4 = None
    activation_post_process_5 = self.activation_post_process_5(spconv2);  spconv2 = None
    cat = torch.cat([activation_post_process_3, activation_post_process_5], dim = 1);  activation_post_process_3 = activation_post_process_5 = None
    activation_post_process_6 = self.activation_post_process_6(cat);  cat = None
    pool = self.pool(activation_post_process_6);  activation_post_process_6 = None
    activation_post_process_7 = self.activation_post_process_7(pool);  pool = None
    squeeze = torch.squeeze(activation_post_process_7, dim = 2);  activation_post_process_7 = None
    activation_post_process_8 = self.activation_post_process_8(squeeze);  squeeze = None
    squeeze_1 = torch.squeeze(activation_post_process_8, dim = 2);  activation_post_process_8 = None
    activation_post_process_9 = self.activation_post_process_9(squeeze_1);  squeeze_1 = None
    fc = self.fc(activation_post_process_9);  activation_post_process_9 = None
    activation_post_process_10 = self.activation_post_process_10(fc);  fc = None
    return activation_post_process_10

# To see more debug info, please use `graph_module.print_readable()`
'''

propagate_split_share_qparams_pre_process(prepared_fx, backend_config)
print(prepared_fx)
'''
GraphModule(
  (activation_post_process_0): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (conv): Conv2d(
    1, 30, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_1): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (spconv1): Conv2d(
    15, 5, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_3): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (spconv2): Conv2d(
    15, 5, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_5): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_6): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (pool): MaxPool2d(kernel_size=28, stride=28, padding=0, dilation=1, ceil_mode=False)
  (activation_post_process_7): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_8): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_9): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (fc): Linear(
    in_features=10, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_10): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
)



def forward(self, x):
    activation_post_process_0 = self.activation_post_process_0(x);  x = None
    conv = self.conv(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_1 = self.activation_post_process_1(conv);  conv = None
    split = torch.functional.split(activation_post_process_1, 15, dim = 1);  activation_post_process_1 = None
    getitem = split[0]
    activation_post_process_2 = self.activation_post_process_1(getitem);  getitem = None
    getitem_1 = split[1];  split = None
    activation_post_process_4 = self.activation_post_process_1(getitem_1);  getitem_1 = None
    spconv1 = self.spconv1(activation_post_process_2);  activation_post_process_2 = None
    activation_post_process_3 = self.activation_post_process_3(spconv1);  spconv1 = None
    spconv2 = self.spconv2(activation_post_process_4);  activation_post_process_4 = None
    activation_post_process_5 = self.activation_post_process_5(spconv2);  spconv2 = None
    cat = torch.cat([activation_post_process_3, activation_post_process_5], dim = 1);  activation_post_process_3 = activation_post_process_5 = None
    activation_post_process_6 = self.activation_post_process_6(cat);  cat = None
    pool = self.pool(activation_post_process_6);  activation_post_process_6 = None
    activation_post_process_7 = self.activation_post_process_7(pool);  pool = None
    squeeze = torch.squeeze(activation_post_process_7, dim = 2);  activation_post_process_7 = None
    activation_post_process_8 = self.activation_post_process_7(squeeze);  squeeze = None
    squeeze_1 = torch.squeeze(activation_post_process_8, dim = 2);  activation_post_process_8 = None
    activation_post_process_9 = self.activation_post_process_7(squeeze_1);  squeeze_1 = None
    fc = self.fc(activation_post_process_9);  activation_post_process_9 = None
    activation_post_process_10 = self.activation_post_process_10(fc);  fc = None
    return activation_post_process_10

# To see more debug info, please use `graph_module.print_readable()`
'''
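Why sharing the input's observer matters: torch.split() (like squeeze) only rearranges values, so every output's range lies inside the input's range. If the outputs reuse the input's scale and zero point, a quantized split is just a slice of the integer tensor and needs no requantization. The pure-Python sketch below illustrates that; `affine_qparams` and `quantize` are hypothetical helpers, and the qparams formula is a simplified version of the per-tensor affine MinMax computation, not code taken from quantizeutils or PyTorch.

```python
def affine_qparams(min_val, max_val, qmin=0, qmax=255):
    """Simplified per-tensor affine qparams (quint8 range by default).

    Like PyTorch's MinMax observers, the range is widened to include 0.
    """
    min_neg = min(min_val, 0.0)
    max_pos = max(max_val, 0.0)
    scale = max((max_pos - min_neg) / (qmax - qmin), 1e-8)  # avoid a zero scale
    zero_point = qmin - round(min_neg / scale)
    return scale, max(qmin, min(qmax, zero_point))


def quantize(values, scale, zero_point, qmin=0, qmax=255):
    """Map floats to clamped integer codes, element by element."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]


# A toy activation split into two halves, as torch.split would do.
x = [-1.2, 0.3, 2.5, -0.7, 1.9, 0.0]
y, z = x[:3], x[3:]

# Shared qparams: a single observer on the split input, reused by both outputs.
scale, zp = affine_qparams(min(x), max(x))
q_x = quantize(x, scale, zp)

# Quantizing each half with the shared qparams equals slicing q_x, so the
# quantized split is a plain slice of the integer tensor (no requantize).
print(quantize(y, scale, zp) == q_x[:3], quantize(z, scale, zp) == q_x[3:])

# An independent observer on z picks different qparams, so its codes no
# longer match the input's and a requantize step would be needed after split.
scale_z, zp_z = affine_qparams(min(z), max(z))
print(quantize(z, scale_z, zp_z), q_x[3:])
```

The same reasoning applies to the squeeze outputs after the pool: squeeze changes only the shape, so reusing the pool output's observer loses nothing.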
  • quantizeutils.fx.utils.pre_process.relu_clamp_backend_config_unshare_observers()
    • In the torch native backend config (the default), ReLU and torch.clamp share the observer of their input. This widens the quantization range unnecessarily: for example, a ReLU output keeps its input's negative minimum, wasting quantization range on values the ReLU can never produce. Applying this function to the backend config before calling prepare() gives these ops their own observers.
import torch
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.observer import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver
from torch.ao.quantization.fx.tracer import QuantizationTracer
from torch.fx import GraphModule
from torch.ao.quantization.fx import prepare
from torch.ao.quantization.backend_config import get_native_backend_config

from quantizeutils.fx.utils.pre_process import relu_clamp_backend_config_unshare_observers

class ExampleModel(torch.nn.Module):
    '''
    Expects mnist input of shape (batch,1,28,28)
    '''
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1,32,1,1)
        self.conv2 = torch.nn.Conv2d(32,10,1,1)
        # self.pool = torch.nn.AdaptiveMaxPool2d((1,1))
        # not supported once quantized, so replace with manual MaxPool2d
        self.pool = torch.nn.MaxPool2d(28,28)
        self.fc = torch.nn.Linear(10, 10)
    def forward(self,x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.relu(x)
        x = self.pool(x)
        x = torch.squeeze(x, dim=2)
        x = torch.squeeze(x, dim=2)
        x = self.fc(x)
        return x

model = ExampleModel()

# define the qconfig_mapping
qconfig_mapping = QConfigMapping().set_global(
        QConfig(
            activation=MovingAverageMinMaxObserver.with_args(
                dtype=torch.quint8,
                qscheme=torch.per_tensor_affine,
                ),
            weight=MovingAveragePerChannelMinMaxObserver.with_args(
                dtype=torch.qint8,
                qscheme=torch.per_channel_symmetric,
                ),
            )
        )

# FX trace the model
tracer = QuantizationTracer(skipped_module_names=[], skipped_module_classes=[])
graph = tracer.trace(model)
traced_fx = GraphModule(tracer.root, graph, 'ExampleModel')
print(traced_fx)
'''
ExampleModel(
  (conv1): Conv2d(1, 32, kernel_size=(1, 1), stride=(1, 1))
  (conv2): Conv2d(32, 10, kernel_size=(1, 1), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=28, stride=28, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=10, out_features=10, bias=True)
)



def forward(self, x):
    conv1 = self.conv1(x);  x = None
    relu = torch.nn.functional.relu(conv1, inplace = False);  conv1 = None
    conv2 = self.conv2(relu);  relu = None
    relu_1 = torch.nn.functional.relu(conv2, inplace = False);  conv2 = None
    pool = self.pool(relu_1);  relu_1 = None
    squeeze = torch.squeeze(pool, dim = 2);  pool = None
    squeeze_1 = torch.squeeze(squeeze, dim = 2);  squeeze = None
    fc = self.fc(squeeze_1);  squeeze_1 = None
    return fc

# To see more debug info, please use `graph_module.print_readable()`
'''

# Skip fusion on purpose to show what happens to ReLU nodes that are not fused with the preceding conv
# FX prepare the quantization nodes
example_inputs = (torch.randn(1,1,28,28),)
backend_config = get_native_backend_config()
backend_config = relu_clamp_backend_config_unshare_observers(backend_config)

prepared_fx = prepare(
    traced_fx,
    qconfig_mapping=qconfig_mapping,
    node_name_to_scope=tracer.node_name_to_scope,
    is_qat=True, # convenient even if not QAT
    example_inputs=example_inputs,
    backend_config=backend_config,
    )
# run one batch through the model so the observers record min/max values
prepared_fx(example_inputs[0])
print(prepared_fx)
'''
GraphModule(
  (activation_post_process_0): MovingAverageMinMaxObserver(min_val=-2.873756170272827, max_val=3.512624740600586)
  (conv1): Conv2d(
    1, 32, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.2244,  0.1366,  0.5535,  0.4781,  0.9987, -0.8585,  0.2149,  0.2896,
              -0.3051, -0.2861,  0.4546,  0.6375, -0.9563, -0.2443,  0.9397,  0.4525,
              -0.8703,  0.0118,  0.7989,  0.4656,  0.8642, -0.8372, -0.6900, -0.2179,
              -0.9575,  0.1994,  0.9602,  0.8782,  0.1776, -0.9443, -0.2989,  0.3896]), max_val=tensor([-0.2244,  0.1366,  0.5535,  0.4781,  0.9987, -0.8585,  0.2149,  0.2896,
              -0.3051, -0.2861,  0.4546,  0.6375, -0.9563, -0.2443,  0.9397,  0.4525,
              -0.8703,  0.0118,  0.7989,  0.4656,  0.8642, -0.8372, -0.6900, -0.2179,
              -0.9575,  0.1994,  0.9602,  0.8782,  0.1776, -0.9443, -0.2989,  0.3896])
    )
  )
  (activation_post_process_1): MovingAverageMinMaxObserver(min_val=-3.461930274963379, max_val=3.918088436126709)
  (activation_post_process_2): MovingAverageMinMaxObserver(min_val=0.0, max_val=3.918088436126709)
  (conv2): Conv2d(
    32, 10, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.1602, -0.1757, -0.1724, -0.1654, -0.1460, -0.1729, -0.1722, -0.1635,
              -0.1721, -0.1758]), max_val=tensor([0.1669, 0.1637, 0.1765, 0.1631, 0.1765, 0.1700, 0.1761, 0.1572, 0.1726,
              0.1767])
    )
  )
  (activation_post_process_3): MovingAverageMinMaxObserver(min_val=-2.4537484645843506, max_val=1.3463588953018188)
  (activation_post_process_4): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (pool): MaxPool2d(kernel_size=28, stride=28, padding=0, dilation=1, ceil_mode=False)
  (activation_post_process_5): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (activation_post_process_6): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (activation_post_process_7): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (fc): Linear(
    in_features=10, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.2888, -0.2640, -0.3119, -0.3112, -0.2534, -0.3156, -0.2953, -0.3095,
              -0.2894, -0.2766]), max_val=tensor([0.2896, 0.2684, 0.2644, 0.2437, 0.3056, 0.3141, 0.3016, 0.2660, 0.2929,
              0.3111])
    )
  )
  (activation_post_process_8): MovingAverageMinMaxObserver(min_val=-0.6559169888496399, max_val=1.2133853435516357)
)



def forward(self, x):
    activation_post_process_0 = self.activation_post_process_0(x);  x = None
    conv1 = self.conv1(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_1 = self.activation_post_process_1(conv1);  conv1 = None
    relu = torch.nn.functional.relu(activation_post_process_1, inplace = False);  activation_post_process_1 = None
    activation_post_process_2 = self.activation_post_process_2(relu);  relu = None
    conv2 = self.conv2(activation_post_process_2);  activation_post_process_2 = None
    activation_post_process_3 = self.activation_post_process_3(conv2);  conv2 = None
    relu_1 = torch.nn.functional.relu(activation_post_process_3, inplace = False);  activation_post_process_3 = None
    activation_post_process_4 = self.activation_post_process_4(relu_1);  relu_1 = None
    pool = self.pool(activation_post_process_4);  activation_post_process_4 = None
    activation_post_process_5 = self.activation_post_process_5(pool);  pool = None
    squeeze = torch.squeeze(activation_post_process_5, dim = 2);  activation_post_process_5 = None
    activation_post_process_6 = self.activation_post_process_6(squeeze);  squeeze = None
    squeeze_1 = torch.squeeze(activation_post_process_6, dim = 2);  activation_post_process_6 = None
    activation_post_process_7 = self.activation_post_process_7(squeeze_1);  squeeze_1 = None
    fc = self.fc(activation_post_process_7);  activation_post_process_7 = None
    activation_post_process_8 = self.activation_post_process_8(fc);  fc = None
    return activation_post_process_8

# To see more debug info, please use `graph_module.print_readable()`
'''
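To put a number on the wasted range, the sketch below plugs the conv1 output range printed above (activation_post_process_1) into a simplified quint8 per-tensor affine formula, once as the shared range and once clamped at 0 as the dedicated ReLU observer (activation_post_process_2) records it. `affine_qparams` is a hypothetical helper, not part of quantizeutils or PyTorch.

```python
def affine_qparams(min_val, max_val, qmin=0, qmax=255):
    """Simplified per-tensor affine qparams; the range always includes 0."""
    min_neg = min(min_val, 0.0)
    max_pos = max(max_val, 0.0)
    scale = max((max_pos - min_neg) / (qmax - qmin), 1e-8)
    zero_point = max(qmin, min(qmax, qmin - round(min_neg / scale)))
    return scale, zero_point


# Range observed on conv1's output in the example above (activation_post_process_1).
conv_min, conv_max = -3.461930274963379, 3.918088436126709

# Shared observer: the ReLU output is quantized with the conv output's range.
shared_scale, shared_zp = affine_qparams(conv_min, conv_max)

# Unshared observer: the ReLU output never goes below 0 (activation_post_process_2).
relu_scale, relu_zp = affine_qparams(0.0, conv_max)

# Number of integer codes available for the non-negative values a ReLU produces
# (codes from zero_point up to 255, inclusive).
shared_codes = 255 - shared_zp + 1
relu_codes = 255 - relu_zp + 1

print(f"shared: scale={shared_scale:.5f}, zero_point={shared_zp}, codes={shared_codes}")
print(f"relu:   scale={relu_scale:.5f}, zero_point={relu_zp}, codes={relu_codes}")
```

With the shared range, only 136 of the 256 codes can represent a ReLU output; with the dedicated observer all 256 do, and the step size is roughly halved (about 0.0154 instead of 0.0289).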
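  • quantizeutils.fx.utils.post_process.fuse_qat_bn_post_process()
    • Fuses modules that QAT leaves unfused (for example, batch normalization) before exporting to ONNX. In the example below, the ConvBnReLU2d modules become ConvReLU2d, with the BatchNorm folded into the convolution weights.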
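import torch
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.observer import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver

from quantizeutils.fx.utils.post_process import fuse_qat_bn_post_process

# prepared_model: an eager-mode model (with QuantStub/DeQuantStub) that has
# already been prepared for QAT, e.g. with torch.ao.quantization.prepare_qat();
# it is not defined in this snippet.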

print(prepared_model)
'''
ExampleModel(
  (quant): QuantStub(
    (activation_post_process): MovingAverageMinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (dequant): DeQuantStub()
  (conv1): ConvBnReLU2d(
    1, 32, kernel_size=(1, 1), stride=(1, 1)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([ 0.9805,  0.9878, -0.9778, -0.9910, -0.9607,  1.0020,  0.9854, -0.9932,
               0.9920, -1.0192,  0.9573,  0.9959,  1.0661, -0.9587, -0.4509,  0.8841,
               0.9185,  0.9696, -0.9722, -1.0099,  1.0207, -1.0131,  0.9228,  0.9731,
              -0.8032, -0.9803, -0.9691,  1.0209, -0.9520,  1.0132,  1.0179, -1.0628],
             device='cuda:0'), max_val=tensor([ 0.9805,  0.9878, -0.9778, -0.9910, -0.9607,  1.0020,  0.9854, -0.9932,
               0.9920, -1.0192,  0.9573,  0.9959,  1.0661, -0.9587, -0.4509,  0.8841,
               0.9185,  0.9696, -0.9722, -1.0099,  1.0207, -1.0131,  0.9228,  0.9731,
              -0.8032, -0.9803, -0.9691,  1.0209, -0.9520,  1.0132,  1.0179, -1.0628],
             device='cuda:0')
    )
    (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=3.486288547515869)
  )
  (bn1): Identity()
  (act1): Identity()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): ConvBnReLU2d(
    32, 64, kernel_size=(1, 1), stride=(1, 1)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-2.1960, -0.2812, -1.7402, -0.8100, -3.3399, -1.0657, -0.1177, -0.6993,
              -0.2671, -4.9234, -0.5415, -3.2926, -0.1593, -0.3427, -0.4231, -3.2403,
              -0.7386, -2.6828, -5.1171, -3.6965, -0.2139, -9.6442, -0.3108, -2.1643,
              -0.7892, -2.2819, -2.0034, -0.4834, -0.3995, -4.6650, -1.5611, -1.3696,
              -3.9522, -0.3022, -1.5632, -0.4557, -5.8931, -0.4400, -1.2626, -0.5098,
              -3.3187, -0.3899, -0.4554, -2.8338, -0.4487, -0.2008, -1.1349, -5.3991,
              -4.4046, -0.4110, -1.2552, -0.5631, -0.3380, -2.7315, -2.2920, -2.1396,
              -0.4084, -6.2974, -1.1824, -0.1679, -2.1181, -0.8331, -1.1392, -5.4736],
             device='cuda:0'), max_val=tensor([2.1027, 0.2706, 2.1971, 0.7181, 2.7690, 0.9632, 0.1437, 0.6662, 0.2961,
              3.6605, 0.5448, 2.8305, 0.1597, 0.3670, 0.4135, 4.4315, 0.7591, 3.2193,
              4.7957, 4.5160, 0.2183, 8.5344, 0.3707, 2.4349, 1.2237, 2.4649, 2.5588,
              0.4111, 0.5481, 5.3878, 1.8818, 0.9792, 3.3811, 0.3097, 0.9306, 0.5589,
              5.2913, 0.3896, 0.8146, 0.8410, 2.5406, 0.2423, 0.4662, 3.2510, 0.4216,
              0.2331, 0.7295, 5.5472, 3.2039, 0.3767, 1.2058, 0.4847, 0.2246, 2.3556,
              2.5913, 1.7738, 0.4303, 6.8638, 0.9290, 0.2386, 2.2515, 1.1236, 0.7875,
              4.1441], device='cuda:0')
    )
    (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=7.873152732849121)
  )
  (bn2): Identity()
  (act2): Identity()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fl1): Flatten(start_dim=1, end_dim=-1)
  (fc1): LinearReLU(
    in_features=3136, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.1110, -0.1682, -0.1074, -0.0716, -0.1544, -0.1686, -0.1512, -0.2297,
              -0.1243, -0.1656], device='cuda:0'), max_val=tensor([0.1568, 0.1229, 0.1314, 0.0929, 0.1936, 0.1594, 0.1684, 0.1468, 0.1751,
              0.1768], device='cuda:0')
    )
    (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=52.048946380615234)
  )
  (fc1act): Identity()
  (fc2): Linear(
    in_features=10, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.4732, -0.7787, -0.3168, -0.1518, -0.6355, -0.3758, -0.5703, -0.3268,
              -0.3168, -0.5046], device='cuda:0'), max_val=tensor([0.3203, 0.4310, 0.3799, 0.3980, 0.3096, 0.2491, 0.4260, 0.2836, 0.3210,
              0.3214], device='cuda:0')
    )
    (activation_post_process): MovingAverageMinMaxObserver(min_val=-21.06275177001953, max_val=15.43133544921875)
  )
)
'''

qconfig = QConfig(
    activation=MovingAverageMinMaxObserver.with_args(
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
        ),
    weight=MovingAveragePerChannelMinMaxObserver.with_args(
        dtype=torch.qint8,
        qscheme=torch.per_channel_symmetric,
        ),
    )
device='cuda:0'

fuse_qat_bn_post_process(
    prepared_model,
    qconfig,
    device,
    update_weight_with_fakequant=False,
    keep_w_fake_quant=True)
print(prepared_model)
'''
ExampleModel(
  (quant): QuantStub(
    (activation_post_process): MovingAverageMinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (dequant): DeQuantStub()
  (conv1): ConvReLU2d(
    1, 32, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([ 0.9773,  0.9846, -0.9746, -0.9878, -0.9576,  0.9987,  0.9822, -0.9899,
               0.9888, -1.0159,  0.9542,  0.9927,  1.0627, -0.9556, -0.4495,  0.8815,
               0.9156,  0.9664, -0.9691, -1.0067,  1.0174, -1.0098,  0.9199,  0.9699,
              -0.8006, -0.9771, -0.9660,  1.0176, -0.9489,  1.0099,  1.0146, -1.0593],
             device='cuda:0'), max_val=tensor([ 0.9773,  0.9846, -0.9746, -0.9878, -0.9576,  0.9987,  0.9822, -0.9899,
               0.9888, -1.0159,  0.9542,  0.9927,  1.0627, -0.9556, -0.4495,  0.8815,
               0.9156,  0.9664, -0.9691, -1.0067,  1.0174, -1.0098,  0.9199,  0.9699,
              -0.8006, -0.9771, -0.9660,  1.0176, -0.9489,  1.0099,  1.0146, -1.0593],
             device='cuda:0')
    )
  )
  (bn1): Identity()
  (act1): Identity()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): ConvReLU2d(
    32, 64, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-2.2004, -0.2815, -1.7259, -0.8108, -3.3225, -1.0663, -0.1178, -0.6998,
              -0.2673, -4.9116, -0.5416, -3.3015, -0.1594, -0.3430, -0.4228, -3.2384,
              -0.7388, -2.6793, -5.0133, -3.6836, -0.2141, -9.6123, -0.3111, -2.1644,
              -0.7847, -2.2629, -1.9944, -0.4836, -0.3984, -4.6541, -1.5561, -1.3597,
              -3.9317, -0.3023, -1.5552, -0.4559, -5.8850, -0.4401, -1.2523, -0.5099,
              -3.3166, -0.3899, -0.4554, -2.8402, -0.4490, -0.2010, -1.1246, -5.3832,
              -4.3736, -0.4115, -1.2560, -0.5640, -0.3383, -2.7325, -2.2894, -2.1422,
              -0.4086, -6.1985, -1.1782, -0.1680, -2.1212, -0.8303, -1.1334, -5.4791],
             device='cuda:0'), max_val=tensor([2.1069, 0.2708, 2.1790, 0.7187, 2.7546, 0.9638, 0.1438, 0.6667, 0.2964,
              3.6516, 0.5449, 2.8382, 0.1598, 0.3673, 0.4133, 4.4289, 0.7594, 3.2151,
              4.6985, 4.5002, 0.2185, 8.5062, 0.3711, 2.4350, 1.2167, 2.4443, 2.5474,
              0.4113, 0.5465, 5.3751, 1.8757, 0.9722, 3.3635, 0.3099, 0.9258, 0.5592,
              5.2839, 0.3897, 0.8079, 0.8413, 2.5389, 0.2423, 0.4662, 3.2584, 0.4219,
              0.2332, 0.7228, 5.5310, 3.1814, 0.3771, 1.2065, 0.4855, 0.2248, 2.3564,
              2.5883, 1.7760, 0.4306, 6.7560, 0.9257, 0.2388, 2.2547, 1.1199, 0.7835,
              4.1483], device='cuda:0')
    )
  )
  (bn2): Identity()
  (act2): Identity()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fl1): Flatten(start_dim=1, end_dim=-1)
  (fc1): LinearReLU(
    in_features=3136, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.1110, -0.1682, -0.1074, -0.0716, -0.1544, -0.1686, -0.1512, -0.2297,
              -0.1243, -0.1656], device='cuda:0'), max_val=tensor([0.1568, 0.1229, 0.1314, 0.0929, 0.1936, 0.1594, 0.1684, 0.1468, 0.1751,
              0.1768], device='cuda:0')
    )
  )
  (fc1act): Identity()
  (fc2): Linear(
    in_features=10, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.4732, -0.7787, -0.3168, -0.1518, -0.6355, -0.3758, -0.5703, -0.3268,
              -0.3168, -0.5046], device='cuda:0'), max_val=tensor([0.3203, 0.4310, 0.3799, 0.3980, 0.3096, 0.2491, 0.4260, 0.2836, 0.3210,
              0.3214], device='cuda:0')
    )
  )
)
'''
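Folding a BatchNorm into the preceding convolution is what allows a ConvBnReLU2d to become a plain ConvReLU2d. The pure-Python sketch below shows the textbook per-channel folding for a 1x1 convolution with a single input channel (one scalar weight per output channel); it illustrates the math only and is not taken from fuse_qat_bn_post_process. The helpers `fold_bn` and `conv_then_bn` and all the numbers are hypothetical.

```python
import math


def fold_bn(weight, bias, gamma, beta, running_mean, running_var, eps=1e-5):
    """Fold per-channel BatchNorm statistics into conv weights and biases.

    weight holds one scalar per output channel (a 1x1 conv with one input
    channel); every other argument is a per-channel list.
    """
    folded_w, folded_b = [], []
    for w, b, g, bt, m, v in zip(weight, bias, gamma, beta, running_mean, running_var):
        factor = g / math.sqrt(v + eps)          # BN scale for this channel
        folded_w.append(w * factor)              # W' = W * gamma / sqrt(var + eps)
        folded_b.append((b - m) * factor + bt)   # b' = (b - mean) * factor + beta
    return folded_w, folded_b


def conv_then_bn(x, w, b, g, bt, m, v, eps=1e-5):
    """Reference path: 1x1 conv on a scalar input, then eval-mode BatchNorm."""
    y = w * x + b
    return (y - m) / math.sqrt(v + eps) * g + bt


weight, bias = [0.98, -0.45], [0.10, -0.20]
gamma, beta = [1.2, 0.8], [0.05, -0.10]
mean, var = [0.30, -0.10], [0.90, 1.50]

fw, fb = fold_bn(weight, bias, gamma, beta, mean, var)
x = 2.0
for c in range(2):
    folded = fw[c] * x + fb[c]
    reference = conv_then_bn(x, weight[c], bias[c], gamma[c], beta[c], mean[c], var[c])
    print(c, folded, reference)  # the two paths agree up to float rounding
```

Because the folded weights differ from the originals, the per-channel weight ranges also shift slightly, which is consistent with the small changes in the weight_fake_quant min/max values printed above.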
  • quantizeutils.fx.utils.post_process.merge_relu_clamp_to_qparams_post_process()
    • Patterns such as Conv+ReLU are fused automatically by the native backend but remain separate nodes when exported to ONNX or other backends. This function removes ReLU and torch.clamp nodes and merges their effect into the quantization range (q_min and q_max) of the preceding node, so no separate activation node is needed.
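from quantizeutils.fx.utils.post_process import merge_relu_clamp_to_qparams_post_process

# prepared_fx: the calibrated model from the relu_clamp_backend_config_unshare_observers example above
merge_relu_clamp_to_qparams_post_process(prepared_fx)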
print(prepared_fx)
'''
GraphModule(
  (activation_post_process_0): MovingAverageMinMaxObserver(min_val=-2.873756170272827, max_val=3.512624740600586)
  (conv1): Conv2d(
    1, 32, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.2244,  0.1366,  0.5535,  0.4781,  0.9987, -0.8585,  0.2149,  0.2896,
              -0.3051, -0.2861,  0.4546,  0.6375, -0.9563, -0.2443,  0.9397,  0.4525,
              -0.8703,  0.0118,  0.7989,  0.4656,  0.8642, -0.8372, -0.6900, -0.2179,
              -0.9575,  0.1994,  0.9602,  0.8782,  0.1776, -0.9443, -0.2989,  0.3896]), max_val=tensor([-0.2244,  0.1366,  0.5535,  0.4781,  0.9987, -0.8585,  0.2149,  0.2896,
              -0.3051, -0.2861,  0.4546,  0.6375, -0.9563, -0.2443,  0.9397,  0.4525,
              -0.8703,  0.0118,  0.7989,  0.4656,  0.8642, -0.8372, -0.6900, -0.2179,
              -0.9575,  0.1994,  0.9602,  0.8782,  0.1776, -0.9443, -0.2989,  0.3896])
    )
  )
  (activation_post_process_2): MovingAverageMinMaxObserver(min_val=0.0, max_val=3.918088436126709)
  (conv2): Conv2d(
    32, 10, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.1602, -0.1757, -0.1724, -0.1654, -0.1460, -0.1729, -0.1722, -0.1635,
              -0.1721, -0.1758]), max_val=tensor([0.1669, 0.1637, 0.1765, 0.1631, 0.1765, 0.1700, 0.1761, 0.1572, 0.1726,
              0.1767])
    )
  )
  (activation_post_process_4): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (pool): MaxPool2d(kernel_size=28, stride=28, padding=0, dilation=1, ceil_mode=False)
  (activation_post_process_5): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (activation_post_process_6): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (activation_post_process_7): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (fc): Linear(
    in_features=10, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.2888, -0.2640, -0.3119, -0.3112, -0.2534, -0.3156, -0.2953, -0.3095,
              -0.2894, -0.2766]), max_val=tensor([0.2896, 0.2684, 0.2644, 0.2437, 0.3056, 0.3141, 0.3016, 0.2660, 0.2929,
              0.3111])
    )
  )
  (activation_post_process_8): MovingAverageMinMaxObserver(min_val=-0.6559169888496399, max_val=1.2133853435516357)
)



def forward(self, x):
    activation_post_process_0 = self.activation_post_process_0(x);  x = None
    conv1 = self.conv1(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_2 = self.activation_post_process_2(conv1);  conv1 = None
    conv2 = self.conv2(activation_post_process_2);  activation_post_process_2 = None
    activation_post_process_4 = self.activation_post_process_4(conv2);  conv2 = None
    pool = self.pool(activation_post_process_4);  activation_post_process_4 = None
    activation_post_process_5 = self.activation_post_process_5(pool);  pool = None
    squeeze = torch.squeeze(activation_post_process_5, dim = 2);  activation_post_process_5 = None
    activation_post_process_6 = self.activation_post_process_6(squeeze);  squeeze = None
    squeeze_1 = torch.squeeze(activation_post_process_6, dim = 2);  activation_post_process_6 = None
    activation_post_process_7 = self.activation_post_process_7(squeeze_1);  squeeze_1 = None
    fc = self.fc(activation_post_process_7);  activation_post_process_7 = None
    activation_post_process_8 = self.activation_post_process_8(fc);  fc = None
    return activation_post_process_8

# To see more debug info, please use `graph_module.print_readable()`
'''
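Why dropping the ReLU is safe once its observer's range starts at 0: with zero_point == 0 and quant_min == 0, quantization clamps every negative input to code 0, which is exactly what the ReLU would have produced. The sketch below checks that equivalence with a hypothetical `fake_quant` helper (quantize, clamp, dequantize) and the ReLU observer range printed above; it is an illustration, not the library's implementation. The same reasoning applies to torch.clamp when the observer's range matches the clamp bounds, up to rounding at the edges.

```python
def fake_quant(x, scale, zero_point, qmin=0, qmax=255):
    """Quantize to a clamped integer code, then dequantize back to float."""
    q = max(qmin, min(qmax, round(x / scale) + zero_point))
    return (q - zero_point) * scale


def relu(x):
    return max(x, 0.0)


# qparams of the observer that used to sit after the ReLU
# (activation_post_process_2 above: min_val=0.0, max_val=3.918...).
scale = 3.918088436126709 / 255
zero_point = 0

conv_outputs = [-3.46, -0.5, -0.001, 0.0, 0.7, 2.2, 3.918088436126709, 5.0]
with_relu = [fake_quant(relu(v), scale, zero_point) for v in conv_outputs]
relu_merged = [fake_quant(v, scale, zero_point) for v in conv_outputs]
print(with_relu == relu_merged)  # quantization alone already applies the ReLU
```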

FX Backend for AIEdgeTorch export

AIEdgeTorch (ai_edge_torch) is a powerful, but still volatile, tool for converting PyTorch models to TensorFlow Lite through PT2E. Since some models can currently only be quantized with FX graphs, I wrote an FX backend configuration that may make FX-quantized models exportable with ai_edge_torch. More on this in my About Quantization guide.

quantizeutils.fx.backend_config.ai_edge_backend



Download files


Source Distribution

quantizeutils-0.1.1.tar.gz (30.1 kB)

Uploaded Source

Built Distribution


quantizeutils-0.1.1-py3-none-any.whl (28.2 kB)

Uploaded Python 3

File details

Details for the file quantizeutils-0.1.1.tar.gz.

File metadata

  • Download URL: quantizeutils-0.1.1.tar.gz
  • Upload date:
  • Size: 30.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.4 CPython/3.12.11 Linux/5.15.153.1-microsoft-standard-WSL2

File hashes

Hashes for quantizeutils-0.1.1.tar.gz
Algorithm    Hash digest
SHA256       37f3ddec21a96e8aa6cf520be48426fc1ef64ca189e2052abbc1e1cdd4172430
MD5          f0a1d8f56f7e9eab0ee28352a4bdf26e
BLAKE2b-256  692d4757799fe193dc9200ac3ab67840086a92ea5ee9d9caea3d98706da93aad


File details

Details for the file quantizeutils-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: quantizeutils-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 28.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.4 CPython/3.12.11 Linux/5.15.153.1-microsoft-standard-WSL2

File hashes

Hashes for quantizeutils-0.1.1-py3-none-any.whl
Algorithm    Hash digest
SHA256       f84bf6ec42c9eb6034c16831fe8070e79142c65da6d732071e8155c44c7c5a7d
MD5          a70b2c5ca58fc6f34e7266944a65ffdd
BLAKE2b-256  fd4faa6a5dd25e416113c76b64f931945c38577195c980e1b1c4bff7b7cf8166

